Commit 716a3243 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_generate' into 'main'

lm evaluation

See merge request ADLR/megatron-lm!262
parents 7a5768ac e5ec27d7
Loading
Loading
Loading
Loading
+38 −1
Original line number Diff line number Diff line
@@ -190,6 +190,37 @@ def generate_samples_input_from_file(model):
            raw_text = None
            context_count += 1

# This function was added to support evaluation tasks such as SQuAD and DROP
# from the https://github.com/EleutherAI/lm-evaluation-harness codebase.
# The lm-evaluation-harness code can call it in the same way as the generate
# function it currently uses for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a continuation of ``context`` for lm evaluation.

    The global ``args`` object is used to hand the output length and the
    end-of-sequence id to the sampling code. Both attributes are put back to
    their previous state before returning, so this call does not change how
    later generation calls in the same process behave.

    Args:
        model: Model passed to ``get_token_stream``; ``model.eval()`` is
            called on it.
        context: Prompt text to continue.
        max_gen_length: Number of tokens to allow beyond the tokenized prompt.
        eos_token_id: Token id stored in ``args.eos_id`` for the duration of
            the call so that sampling can terminate on it.

    Returns:
        The detokenized output with its first ``len(context)`` characters
        removed, or an empty string if the token stream yields nothing.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    out_seq_length = max_gen_length + len(context_tokens)

    # Remember the current values (or their absence) of the attributes that
    # are overridden below, so they can be restored afterwards.
    unset = object()
    saved = {name: getattr(args, name, unset)
             for name in ('out_seq_length', 'eos_id')}

    args.out_seq_length = out_seq_length
    args.eos_id = eos_token_id

    last_step = unset
    try:
        with torch.no_grad():
            token_stream = get_token_stream(model, [context_tokens])
            for counter, step in enumerate(token_stream):
                last_step = step
                if counter == out_seq_length:
                    break
    finally:
        # Undo the temporary overrides even if generation raised.
        for name, value in saved.items():
            if value is unset:
                delattr(args, name)
            else:
                setattr(args, name, value)

    # The stream produced no step at all: there is nothing to decode.
    if last_step is unset:
        return ''

    decode_tokens, _ = last_step
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    # NOTE(review): slicing by len(context) assumes the detokenized output
    # starts with the prompt text at its original character length — confirm
    # for tokenizers that normalize whitespace.
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens


def generate_samples_interactive(model, print_frequency=24):

@@ -438,6 +469,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # Added eos_id to support generate_samples_eval, which passes eos_id
        # as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0