Commit 5a6431f5 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

addressed comments

parent 5c2ce593
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -466,6 +466,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id id found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else: