Commit 5c2ce593 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

addressed comments

parent 44bfcb32
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -207,12 +207,13 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]
            if counter == args.out_seq_length:
                break
 
    return trim_decode_tokens