Commit 76928caa authored by Neel Kant
Browse files

Create tensors on cuda rather than copying

parent 2a3b445d
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule):
    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
            query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
@@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule):
    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model"""
        if self.use_block_model:
            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
            block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
+2 −2
Original line number Diff line number Diff line
@@ -99,8 +99,8 @@ def forward_step(data_iterator, model):
    global_batch_size = int(batch_size * data_parallel_size)

    all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
    all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
    all_block_logits = all_query_logits.clone().cuda()
    all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
    all_block_logits = all_query_logits.clone()

    # record this process's data
    all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits