megatron/text_generation_utils.py (+1 −10)

```diff
@@ -190,7 +190,7 @@ def generate_samples_input_from_file(model):
             raw_text = None
             context_count += 1
 
-def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
+def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     # Generate samples for lm evaluation
     # NEED TO THINK ABOUT eos token
 
@@ -202,17 +202,8 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
     context_tokens = tokenizer.tokenize(context)
     args.out_seq_length = max_gen_length + len(context_tokens)
     args.recompute = True #set this default value
     args.eos_id = eos_token_id
-
-    if not do_sample:
-        args.greedy = True
-    else:
-        # set similar to huggingface
-        args.top_p = 1.0
-        args.temperature = 1.0
-        args.top_k = 50
-
 
     with torch.no_grad():
         token_stream = get_token_stream(model, [context_tokens])
         for counter, decode_tokens in enumerate(token_stream):
```
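Note for context: with `do_sample` removed from the signature, the greedy-vs-sampling choice is no longer made inside `generate_samples_eval`; callers now have to configure the global Megatron `args` namespace before invoking it. Below is a minimal, hypothetical call-site wrapper showing how the old switch could be reproduced; the helper name `eval_generate`, the default flag values, and the `model` argument are illustrative assumptions, not part of this change.

```python
# Hypothetical call-site wrapper (not part of this PR): reproduces the removed
# do_sample switch by mutating the global Megatron args before generating.
from megatron import get_args
from megatron.text_generation_utils import generate_samples_eval


def eval_generate(model, prompt, max_gen_length, eos_token_id, do_sample=False):
    args = get_args()
    if not do_sample:
        # Old `do_sample=False` path: plain greedy decoding.
        args.greedy = True
    else:
        # Old `else` branch: sampling with HuggingFace-like defaults.
        args.greedy = False
        args.top_p = 1.0
        args.temperature = 1.0
        args.top_k = 50
    # `model` is assumed to be an already-initialized Megatron GPT model.
    return generate_samples_eval(model, prompt, max_gen_length, eos_token_id)
```

One consequence worth flagging: decoding behavior now depends on caller-side mutation of a process-global `args`, and any existing caller that still passes `do_sample` positionally will raise a `TypeError` rather than being silently ignored.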