megatron/text_generation_utils.py (+13 −0)

```diff
@@ -151,6 +151,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
     return tokens[:, :context_length]
 
 def generate(model, sentences=None, max_len=0):
+    model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
         send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
@@ -168,6 +169,18 @@ def generate(model, sentences=None, max_len=0):
         resp_sentences.append(tokenizer.detokenize(decode_token))
     return resp_sentences
 
+def generate_samples_eval(model, context, max_gen_length, eos_token_id):
+    """
+    This function is here to provide a matching API for a legacy task
+    This implementation hasn't been tested yet to make sure it matches
+    """
+    assert False, "Implementation untested"
+    args = get_args()
+    args.eos_id = eos_token_id
+    raw_text_len = len(context)
+    resp_sentences = generate(model, [context], max_gen_length)
+    return resp_sentences[0][raw_text_len:]
+
 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
     return (1 - boolean) * val1 + boolean * val2
```
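For context, a minimal sketch of what the new `generate_samples_eval` wrapper is meant to do once the `assert False` guard is removed: `generate` returns full detokenized strings (prompt plus continuation), and slicing by the raw prompt length keeps only the generated text, matching the legacy API. The response string below is invented purely for illustration.

```python
# Hypothetical illustration of the prompt-stripping step in
# generate_samples_eval(); the model output here is made up.
context = "The capital of France is"
raw_text_len = len(context)

# generate(model, [context], max_gen_length) returns a list of full
# detokenized sentences, prompt included; pretend this is its output.
resp_sentence = "The capital of France is Paris."

# The wrapper returns only the newly generated continuation.
print(resp_sentence[raw_text_len:])  # " Paris."
```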
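The unchanged `switch` helper shown as diff context does an element-wise select by casting the mask to the value dtype and blending arithmetically. A self-contained sketch of its behavior, assuming `val1`, `val2`, and the mask broadcast to the same shape:

```python
import torch

def switch(val1, val2, boolean):
    # Cast the mask to the value dtype, then blend: where the mask is 1
    # the result takes val2, where it is 0 it takes val1.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

val1 = torch.tensor([1.0, 2.0, 3.0])
val2 = torch.tensor([10.0, 20.0, 30.0])
mask = torch.tensor([False, True, False])
print(switch(val1, val2, mask))  # tensor([ 1., 20.,  3.])
```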