Commit 44bfcb32 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

modified the params

parent 045959cb
Loading
Loading
Loading
Loading
+1 −10
Original line number Diff line number Diff line
@@ -190,7 +190,7 @@ def generate_samples_input_from_file(model):
            raw_text = None
            context_count += 1

def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    # Generate samples for lm evaluation
    # NEED TO THINK ABOUT eos token

@@ -202,17 +202,8 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sampl

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.recompute = True #set this default value
    args.eos_id = eos_token_id

    if not do_sample:
        args.greedy = True
    else:
        # set similar to huggngface
        args.top_p = 1.0
        args.temperature = 1.0
        args.top_k = 50

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):