Commit ccf0eae2 authored by zihanl

use new text generation

parent b3cd8a47
+32 −41
@@ -26,19 +26,26 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from tasks.knwl_dialo.utils import get_token_stream
-# from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import generate_and_post_process
 
 
-def call_model_api(inputs):
+def call_model_api(inputs, tokens_to_generate):
     """Calling the model api to get the output generations"""
-    # TODO
-    # Implement the model api, and get output generations from the inputs
-    # After that, return the output generations
-
-    # outputs = call_model_api(inputs)
-    # return outputs
-    pass
+    args = get_args()
+
+    # The following is an example of using the Megatron API.
+    # You can also implement your own API function to replace this part.
+    headers = {'Content-Type': 'application/json; charset=UTF-8'}
+    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
+    data_json = json.dumps(data)
+    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
+
+    input_len = len(inputs)
+    outputs = outputs[input_len:]
+    outputs = outputs.split("\n")[0].strip()
+
+    return outputs
 
 
 def read_prompts(prompt_path, prompt_type, n_example):
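
For reference, the request pattern the new call_model_api follows can be exercised on its own. Below is a minimal sketch, assuming a Megatron text generation server is already running; the URL and the token budget are illustrative placeholders, while the payload keys ("prompts", "tokens_to_generate", "top_k") and the PUT-with-JSON shape mirror the hunk above:

    import json
    import requests

    def query_generation_server(prompt, url, tokens_to_generate=32):
        # Same payload shape as call_model_api above; top_k=1 requests greedy decoding.
        headers = {'Content-Type': 'application/json; charset=UTF-8'}
        data = json.dumps({"prompts": [prompt],
                           "tokens_to_generate": tokens_to_generate,
                           "top_k": 1})
        outputs = requests.put(url, headers=headers, data=data).json()["text"][0]
        # The server echoes the prompt, so trim it and keep only the first generated line,
        # matching the post-processing in call_model_api.
        return outputs[len(prompt):].split("\n")[0].strip()

    # Hypothetical endpoint; args.megatron_api_url plays this role in the diff.
    # print(query_generation_server("Hello!", "http://localhost:5000/api"))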
@@ -107,7 +114,7 @@ def generate_samples_by_calling_api():
 
         # prepare the inputs for the api
         if args.prompt_type == "knowledge":
-            # inputs = prompt + current test
+            ## inputs = prompt + current test
             # get the prompt
             turns = splits[1].split(" [SEP] ")
             last_turn = turns[-1]
@@ -216,7 +223,6 @@ def generate_samples_by_prompting_input_from_file(model):
                 instance = instance.strip()
                 prompt += instance + " \n"
 
-    context_count = 0
     input_pos = 0
     model.eval()
     # perform prompting
@@ -261,47 +267,32 @@ def generate_samples_by_prompting_input_from_file(model):
 
                 input_pos += 1
                 raw_text_len = len(raw_text)
-                context_tokens = tokenizer.tokenize(raw_text)
 
             else:
-                context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                # raw_text = "EMPTY TEXT"
+                raw_text = "EMPTY TEXT"
 
             if input_pos % 100 == 0:
                 print_rank_0("input_pos: %d" % input_pos)
 
-            # get the generation outputs (in decode_tokens)
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-            # outputs = generate_and_post_process(
-            #             model=model,
-            #             prompts=[raw_text],
-            #             tokens_to_generate=args.out_seq_length,
-            #             top_k_sampling=1)
-            # prompts_plus_generations = outputs[0]
+            outputs = generate_and_post_process(
+                        model=model,
+                        prompts=[raw_text],
+                        tokens_to_generate=args.out_seq_length,
+                        top_k_sampling=1)
+            prompts_plus_generations = outputs[0]
+            prompts_plus_generations = prompts_plus_generations[0]
 
             # write the generated output to the output file
             if mpu.get_tensor_model_parallel_rank() == 0:
                 if mpu.is_pipeline_first_stage():
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(
-                        decode_tokens)[raw_text_len:]
-
-                    generated_output = trim_decode_tokens.split("\n")[0]
-                    generated_output = generated_output.strip()
-                    fname_out.write(generated_output)
-                    fname_out.write("\n")
 
-                    # generations = prompts_plus_generations[raw_text_len:]
-                    # generations = generations.split("\n")[0]
-                    # generations = generations.strip()
-                    # fname_out.write(generations)
-                    # fname_out.write("\n")
+                    generations = prompts_plus_generations[raw_text_len:]
+                    generations = generations.split("\n")[0]
+                    generations = generations.strip()
+                    fname_out.write(generations)
+                    fname_out.write("\n")
 
             raw_text = None
-            context_count += 1
             if input_pos == input_count:
                 return

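The in-process path now mirrors the API path's post-processing. Below is a minimal sketch of how the generate_and_post_process output is unpacked, assuming (as the hunk implies) that the first element of the returned tuple is the list of prompt-plus-generation strings, one per input prompt:

    from megatron.text_generation import generate_and_post_process

    def generate_one(model, raw_text, out_seq_length):
        outputs = generate_and_post_process(
            model=model,
            prompts=[raw_text],                 # batch with a single prompt
            tokens_to_generate=out_seq_length,
            top_k_sampling=1)                   # greedy decoding, as in the diff
        # outputs[0] holds the prompt+generation strings; take the only entry.
        prompts_plus_generations = outputs[0][0]
        # Trim the echoed prompt and keep only the first generated line.
        return prompts_plus_generations[len(raw_text):].split("\n")[0].strip()

The mpu rank checks around the file write are kept so that only a single rank writes each generation to the output file.
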
@@ -309,7 +300,7 @@ def generate_samples_by_prompting_input_from_file(model):
 def main():
 
     args = get_args()
-    if args.api_prompting:
+    if args.api_prompt:
         # obtain the generations by calling the api
         generate_samples_by_calling_api()
         return
@@ -319,7 +310,7 @@ def main():
         exit()
 
     # Set up model and load checkpoint.
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
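
Taken together with the imports in the first hunk, the setup sequence main() follows after this change looks roughly as below. This is a hedged sketch: the assert/unwrap of the returned module list follows the usual Megatron inference pattern and is an assumption, not shown in the diff; model_provider is the GPT builder defined earlier in this file.

    args = get_args()
    # wrap_with_ddp=False skips the DDP wrapper, which is unnecessary for
    # inference-only use; get_model returns a list of module chunks.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)  # no optimizer/scheduler state needed
    assert len(model) == 1                      # assumed: a single model chunk
    model = model[0]
    model.eval()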