Loading megatron/text_generation_utils.py +2 −0 Original line number Diff line number Diff line Loading @@ -466,6 +466,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths, with torch.no_grad(): context_length = context_lengths.min().item() # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. if hasattr(args, 'eos_id'): eos_id = args.eos_id else: Loading Loading
megatron/text_generation_utils.py +2 −0 Original line number Diff line number Diff line Loading @@ -466,6 +466,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths, with torch.no_grad(): context_length = context_lengths.min().item() # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. if hasattr(args, 'eos_id'): eos_id = args.eos_id else: Loading