Loading generate_samples.py +1 −1 Original line number Diff line number Diff line Loading @@ -417,7 +417,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask context_length += 1 counter += 1 done_token = (prev == eos_id).byte() done_token = (prev == eos_id).byte() & started.byte() just_finished = (done_token & ~is_done).bool() lengths[just_finished.view(-1)] = context_length was_done = is_done Loading Loading
generate_samples.py +1 −1 Original line number Diff line number Diff line Loading @@ -417,7 +417,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask context_length += 1 counter += 1 done_token = (prev == eos_id).byte() done_token = (prev == eos_id).byte() & started.byte() just_finished = (done_token & ~is_done).bool() lengths[just_finished.view(-1)] = context_length was_done = is_done Loading