Commit 5f433016 authored by zihanl's avatar zihanl
Browse files

add api option

parent fec9668d
Loading
Loading
Loading
Loading
+121 −3
Original line number Diff line number Diff line
@@ -30,6 +30,121 @@ from tasks.knwl_dialo.utils import get_token_stream
# from megatron.text_generation import generate_and_post_process


def call_model_api(inputs):
    """Calling the model api to get the output generations"""
    # TODO
    # Implement the model api, and get output generations from the inputs
    # After that, return the output generations

    # outputs = call_model_api(inputs)
    # return outputs
    pass


def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

        return prompt


def generate_samples_by_calling_api():
    """ Generate outputs by calling"""
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
                "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)
        
    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = input_str.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            # inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"

        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api, 
        # and write to the output file
        generations = call_model_api(inputs)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

@@ -124,9 +239,7 @@ def generate_samples_by_prompting_input_from_file(model):

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    context = turns[-1]
                    raw_text += "( " + context + " ) " + topic + " =>"
                    raw_text += "( " + last_turn + " ) " + topic + " =>"
                
                else:
                    # first add the prompt into the raw_text
@@ -196,6 +309,11 @@ def generate_samples_by_prompting_input_from_file(model):
def main():

    args = get_args()
    if args.api_prompting:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()
+2 −0
Original line number Diff line number Diff line
@@ -102,6 +102,8 @@ def get_tasks_args(parser):
                       help='datapath for golden sentences')
    group.add_argument('--out-seq-length', type=int, default=100,
                       help='output sequence length')
    group.add_argument('--api-prompt', default=False, action="store_true",
                       help='setup model api for prompting')

    return parser