Loading megatron/text_generation_utils.py +38 −1 Original line number Diff line number Diff line Loading @@ -190,6 +190,37 @@ def generate_samples_input_from_file(model): raw_text = None context_count += 1 # We added this function to support the tasks evaluation such as squad # and drop in the https://github.com/EleutherAI/lm-evaluation-harness # codebase. The lm-evaluation-harness code can now call this function # similar to their current generate function call used for gpt style models. def generate_samples_eval(model, context, max_gen_length, eos_token_id): # Generate samples for lm evaluation # NEED TO THINK ABOUT eos token args = get_args() tokenizer = get_tokenizer() raw_text_len = len(context) model.eval() context_tokens = tokenizer.tokenize(context) args.out_seq_length = max_gen_length + len(context_tokens) args.eos_id = eos_token_id with torch.no_grad(): token_stream = get_token_stream(model, [context_tokens]) for counter, decode_tokens in enumerate(token_stream): if counter == args.out_seq_length: break decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize( decode_tokens)[raw_text_len:] return trim_decode_tokens def generate_samples_interactive(model, print_frequency=24): Loading Loading @@ -438,6 +469,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths, model.eval() with torch.no_grad(): context_length = context_lengths.min().item() # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. if hasattr(args, 'eos_id'): eos_id = args.eos_id else: eos_id = tokenizer.eod counter = 0 Loading Loading
megatron/text_generation_utils.py +38 −1 Original line number Diff line number Diff line Loading @@ -190,6 +190,37 @@ def generate_samples_input_from_file(model): raw_text = None context_count += 1 # We added this function to support the tasks evaluation such as squad # and drop in the https://github.com/EleutherAI/lm-evaluation-harness # codebase. The lm-evaluation-harness code can now call this function # similar to their current generate function call used for gpt style models. def generate_samples_eval(model, context, max_gen_length, eos_token_id): # Generate samples for lm evaluation # NEED TO THINK ABOUT eos token args = get_args() tokenizer = get_tokenizer() raw_text_len = len(context) model.eval() context_tokens = tokenizer.tokenize(context) args.out_seq_length = max_gen_length + len(context_tokens) args.eos_id = eos_token_id with torch.no_grad(): token_stream = get_token_stream(model, [context_tokens]) for counter, decode_tokens in enumerate(token_stream): if counter == args.out_seq_length: break decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize( decode_tokens)[raw_text_len:] return trim_decode_tokens def generate_samples_interactive(model, print_frequency=24): Loading Loading @@ -438,6 +469,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths, model.eval() with torch.no_grad(): context_length = context_lengths.min().item() # added eos_id to support the function generate_samples_eval that passes # eos_id as an argument and needs termination when that id id found. if hasattr(args, 'eos_id'): eos_id = args.eos_id else: eos_id = tokenizer.eod counter = 0 Loading