megatron/text_generation_utils.py  +2 −2

@@ -270,7 +270,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

     lengths = torch.ones([batch_size]).long().cuda() * maxlen

-    while context_length <= (maxlen):
+    while context_length < maxlen:
         types2use = None
         if counter == 0:
             tokens2use = tokens[:, :context_length]

@@ -316,7 +316,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

             if output_logits is None:
                 output_context = F.log_softmax(output[:, :context_length, :], 2)
-                indices = torch.unsqueeze(tokens[:, :context_length], 2)
+                indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
                 output_logits = torch.gather(output_context, 2, indices).squeeze(2)
             else:
                 indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
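The second hunk changes which token ids are gathered when scoring the context: the logit at position i is the model's prediction for the token at position i + 1, so the indices must be shifted by one. A minimal standalone sketch of that indexing (not the Megatron code itself; the toy tensor shapes are assumptions for illustration):

import torch
import torch.nn.functional as F

# Toy stand-ins for the model output and the token buffer.
batch_size, context_length, vocab_size = 2, 5, 11
output = torch.randn(batch_size, context_length, vocab_size)          # hypothetical logits
tokens = torch.randint(vocab_size, (batch_size, context_length + 1))  # hypothetical token ids

output_context = F.log_softmax(output[:, :context_length, :], 2)

# Old indexing: gathers the log-prob of the token *at* each position,
# i.e. a token the model has already consumed as input.
old_indices = torch.unsqueeze(tokens[:, :context_length], 2)

# Fixed indexing: gathers the log-prob of the *next* token at each position,
# which is what the logit at that position actually predicts.
new_indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)

output_logits = torch.gather(output_context, 2, new_indices).squeeze(2)
print(output_logits.shape)  # torch.Size([2, 5]): one log-prob per predicted token

The first hunk is a loop-bound fix in the same spirit: with `<=`, the loop would run one extra step past `maxlen`; `<` stops generation once the maximum length is reached.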