Loading megatron/model/realm_model.py +2 −2 Original line number Diff line number Diff line Loading @@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule): def embed_query(self, query_tokens, query_attention_mask): """Embed a batch of tokens using the query model""" if self.use_query_model: query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda() query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0) query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) return query_ict_logits else: Loading @@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule): def embed_block(self, block_tokens, block_attention_mask): """Embed a batch of tokens using the block model""" if self.use_block_model: block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda() block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0) block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) return block_ict_logits else: Loading pretrain_bert_ict.py +2 −2 Original line number Diff line number Diff line Loading @@ -99,8 +99,8 @@ def forward_step(data_iterator, model): global_batch_size = int(batch_size * data_parallel_size) all_logits_shape = (int(global_batch_size), int(query_logits.shape[1])) all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda() all_block_logits = all_query_logits.clone().cuda() all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0) all_block_logits = all_query_logits.clone() # record this processes' data all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits Loading Loading
megatron/model/realm_model.py +2 −2 Original line number Diff line number Diff line Loading @@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule): def embed_query(self, query_tokens, query_attention_mask): """Embed a batch of tokens using the query model""" if self.use_query_model: query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda() query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0) query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) return query_ict_logits else: Loading @@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule): def embed_block(self, block_tokens, block_attention_mask): """Embed a batch of tokens using the block model""" if self.use_block_model: block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda() block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0) block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) return block_ict_logits else: Loading
pretrain_bert_ict.py +2 −2 Original line number Diff line number Diff line Loading @@ -99,8 +99,8 @@ def forward_step(data_iterator, model): global_batch_size = int(batch_size * data_parallel_size) all_logits_shape = (int(global_batch_size), int(query_logits.shape[1])) all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda() all_block_logits = all_query_logits.clone().cuda() all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0) all_block_logits = all_query_logits.clone() # record this processes' data all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits Loading