megatron/text_generation_utils.py  (+40 −1)

@@ -190,6 +190,41 @@ def generate_samples_input_from_file(model):
             raw_text = None
             context_count += 1
 
+
+def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
+    # Generate samples for lm evaluation
+    # NEED TO THINK ABOUT eos token
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    raw_text_len = len(context)
+    model.eval()
+
+    context_tokens = tokenizer.tokenize(context)
+    args.out_seq_length = max_gen_length + len(context_tokens)
+    args.recompute = True  # set this default value
+    args.eos_id = eos_token_id
+
+    if not do_sample:
+        args.greedy = True
+    else:
+        # set similar to huggingface
+        args.top_p = 1.0
+        args.temperature = 1.0
+        args.top_k = 50
+
+    with torch.no_grad():
+        token_stream = get_token_stream(model, [context_tokens])
+        for counter, decode_tokens in enumerate(token_stream):
+            decode_tokens, _ = decode_tokens
+            decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+            trim_decode_tokens = tokenizer.detokenize(
+                decode_tokens)[raw_text_len:]
+            if counter == args.out_seq_length:
+                break
+
+    return trim_decode_tokens
 
 
 def generate_samples_interactive(model, print_frequency=24):

@@ -438,6 +473,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
+        if hasattr(args, 'eos_id'):
+            eos_id = args.eos_id
+        else:
+            eos_id = tokenizer.eod
 
         counter = 0
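For context, a minimal sketch of how an evaluation harness might call the new helper. The wrapper function, its name, and the choice of tokenizer.eod as the eos token are illustrative assumptions and not part of the diff; it presumes the usual Megatron setup (initialization and a loaded GPT model) has already run, as in the existing generation scripts.

# Hypothetical usage sketch -- not part of the diff. Assumes megatron has
# been initialized and `model` is a loaded GPT model in eval-ready state.
from megatron import get_tokenizer
from megatron.text_generation_utils import generate_samples_eval


def greedy_continuation(model, prompt, max_gen_length=64):
    """Return the greedy continuation of `prompt` as a plain string."""
    tokenizer = get_tokenizer()
    # do_sample=False takes the args.greedy branch of generate_samples_eval;
    # passing tokenizer.eod as eos_token_id is an assumption here, chosen to
    # match the fallback that sample_sequence_batch would use anyway.
    return generate_samples_eval(model, prompt, max_gen_length,
                                 tokenizer.eod, do_sample=False)

The hasattr(args, 'eos_id') check in sample_sequence_batch is what keeps existing callers unaffected: only generate_samples_eval sets args.eos_id, so every other generation path still stops on tokenizer.eod as before.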