Loading megatron/text_generation_utils.py +5 −4 Original line number Diff line number Diff line Loading @@ -207,12 +207,13 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id): with torch.no_grad(): token_stream = get_token_stream(model, [context_tokens]) for counter, decode_tokens in enumerate(token_stream): if counter == args.out_seq_length: break decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize( decode_tokens)[raw_text_len:] if counter == args.out_seq_length: break return trim_decode_tokens Loading Loading
megatron/text_generation_utils.py +5 −4 Original line number Diff line number Diff line Loading @@ -207,12 +207,13 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id): with torch.no_grad(): token_stream = get_token_stream(model, [context_tokens]) for counter, decode_tokens in enumerate(token_stream): if counter == args.out_seq_length: break decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize( decode_tokens)[raw_text_len:] if counter == args.out_seq_length: break return trim_decode_tokens Loading