Commit 41df5ff7 authored by rprenger's avatar rprenger
Browse files

Fixing bug where the wrong token was used to index the log probabilities

parent e0bf5199
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -270,7 +270,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
        
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

-        while context_length <= (maxlen):
+        while context_length < maxlen:
            types2use = None
            if counter == 0:
                tokens2use = tokens[:, :context_length]
@@ -316,7 +316,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                
                if output_logits is None:
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
-                    indices = torch.unsqueeze(tokens[:, :context_length],2)
+                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                else:
                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)