# Scraped diff view: tasks/knwl_dialo/prompt.py (+121 -3), hunk @@ -30,6 +30,121 @@
# Unchanged context shown directly above the hunk:
#   from tasks.knwl_dialo.utils import get_token_stream
#   # from megatron.text_generation import generate_and_post_process


def call_model_api(inputs):
    """Call the model api and return the output generation for ``inputs``.

    This is a placeholder: no api client is wired in. Replace the body with
    the request to the model service and return its output.

    Args:
        inputs: The prompt string to send to the model.

    Returns:
        The generated text. The caller passes it straight to ``file.write``,
        so it has to be a ``str``.

    Raises:
        NotImplementedError: Always, until the api call is implemented.
    """
    # TODO: implement the model api call and return the output generations.
    # Fail loudly on purpose: falling through would return None, which the
    # caller would then try to write to the output file.
    raise NotImplementedError(
        "call_model_api is a stub; implement the model api call and "
        "return the generated text.")


def _join_examples(examples):
    """Strip every example and terminate it with a space and a newline."""
    return "".join(example.strip() + " \n" for example in examples)


def read_prompts(prompt_path, prompt_type, n_example):
    """Read the prompt examples for knowledge or response generation.

    Args:
        prompt_path: Path of the prompt file. For ``"knowledge"`` every
            non-blank line is a JSON object whose first key maps to a list
            of example strings. For any other prompt type the file holds
            one example per line.
        prompt_type: ``"knowledge"`` selects the per-key dictionary format;
            any other value selects the response format.
        n_example: Number of leading lines kept for response prompts. It is
            not used for knowledge prompts.

    Returns:
        For ``"knowledge"``: a dict mapping each key to its prompt string.
        Otherwise: one prompt string. In both cases every example is
        stripped and followed by a space and a newline.
    """
    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(prompt_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank lines instead of failing in json.loads
                    continue
                line_dict = json.loads(line)
                if not line_dict:
                    continue
                key = next(iter(line_dict))
                # only the first occurrence of a key is kept
                if key not in prompt_examples_dict:
                    prompt_examples_dict[key] = _join_examples(line_dict[key])
        return prompt_examples_dict

    # prompts for the response generation
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt_examples = f.readlines()[:n_example]
    return _join_examples(prompt_examples)


def _build_knowledge_inputs(splits, prompt_dict):
    """Build the knowledge-generation input for one tab-split test sample."""
    topic = splits[0]
    last_turn = splits[1].split(" [SEP] ")[-1]
    # the prompt is looked up by "<topic> <last turn>"
    prompt = prompt_dict[topic + " " + last_turn]
    # inputs = prompt + current test
    return prompt + "( " + last_turn + " ) " + topic + " =>"


def _build_response_inputs(splits, prompt):
    """Build the response-generation input for one tab-split test sample."""
    topic = splits[0]
    last_turn = splits[1].split(" [SEP] ")[-1]
    knowledge = splits[2]
    last_turn = " ".join(word_tokenize(last_turn)).strip()
    knowledge = " ".join(word_tokenize(knowledge)).strip()
    # inputs = prompt + current test
    return (prompt
            + "Topic: " + topic + ". "
            + "User says: " + last_turn + " "
            + "We know that: " + knowledge + " "
            + "System replies:")


def generate_samples_by_calling_api():
    """Generate outputs by calling the model api.

    Reads the prompts from ``args.prompt_file``, builds one input per
    non-blank line of ``args.sample_input_file`` (tab-separated: topic,
    turns joined by " [SEP] ", and for response prompts the knowledge),
    sends it to ``call_model_api`` and writes each generation on its own
    line of ``args.sample_output_file``.

    Raises:
        AssertionError: If ``args.prompt_type`` is neither ``"knowledge"``
            nor ``"response"``.
        KeyError: For knowledge prompts, if no prompt exists for the
            sample's "<topic> <last turn>" key.
    """
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    prompts = read_prompts(
        args.prompt_file, args.prompt_type, args.num_prompt_examples)
    if args.prompt_type == "knowledge":
        build_inputs = _build_knowledge_inputs
    else:
        build_inputs = _build_response_inputs

    # NOTE(review): the output path is taken from args, mirroring
    # args.sample_input_file; no other binding of it is visible here.
    with open(args.sample_input_file, "r", encoding="utf-8") as fname, \
            open(args.sample_output_file, "w", encoding="utf-8") as fname_out:
        # call the api to get the output generations
        for test_sample in fname:
            test_sample = test_sample.strip()
            if not test_sample:
                # a blank line carries no topic/turns to build an input from
                continue
            splits = test_sample.split("\t")
            inputs = build_inputs(splits, prompts)

            # get the output generations from the api,
            # and write to the output file
            generations = call_model_api(inputs)
            fname_out.write(generations)
            fname_out.write("\n")


# ---------------------------------------------------------------------------
# Remaining hunks of the scraped diff. These definitions are only partially
# visible, so they are kept as comments exactly as shown and are not edited.
# ---------------------------------------------------------------------------
# def model_provider(pre_process=True, post_process=True):
#     """Build the model."""
#
# @@ -124,9 +239,7 @@ def generate_samples_by_prompting_input_from_file(model):
#     # construct inputs for knowledge generation
#     # then add the constructed inputs into the raw_text
#     turns = splits[1].split(" [SEP] ")
#     context = turns[-1]
#     raw_text += "( " + context + " ) " + topic + " =>"
#     raw_text += "( " + last_turn + " ) " + topic + " =>"
#     else:
#     # first add the prompt into the raw_text
#
# @@ -196,6 +309,11 @@ def generate_samples_by_prompting_input_from_file(model):
# def main():
#     args = get_args()
#     if args.api_prompting:
#         # obtain the generations by calling the api
#         generate_samples_by_calling_api()
#         return
#     if args.num_layers_per_virtual_pipeline_stage is not None:
#         print("Interleaved pipeline schedule is not yet supported for text generation.")
#         exit()
#
# tasks/main.py (+2 -0), hunk @@ -102,6 +102,8 @@ def get_tasks_args(parser):
#     help='datapath for golden sentences')
#     group.add_argument('--out-seq-length', type=int, default=100,
#                        help='output sequence length')
#     group.add_argument('--api-prompt', default=False, action="store_true",
#                        help='setup model api for prompting')
#     return parser
#
# NOTE(review): main() reads args.api_prompting, but the option above is
# declared as --api-prompt, which argparse stores as args.api_prompt. Unless
# another option defines api_prompting, one of the two names needs to change.
# Scraped diff view: tasks/knwl_dialo/prompt.py (+121 -3), hunk @@ -30,6 +30,121 @@
# Unchanged context shown directly above the hunk:
#   from tasks.knwl_dialo.utils import get_token_stream
#   # from megatron.text_generation import generate_and_post_process


def call_model_api(inputs):
    """Call the model api and return the output generation for ``inputs``.

    This is a placeholder: no api client is wired in. Replace the body with
    the request to the model service and return its output.

    Args:
        inputs: The prompt string to send to the model.

    Returns:
        The generated text. The caller passes it straight to ``file.write``,
        so it has to be a ``str``.

    Raises:
        NotImplementedError: Always, until the api call is implemented.
    """
    # TODO: implement the model api call and return the output generations.
    # Fail loudly on purpose: falling through would return None, which the
    # caller would then try to write to the output file.
    raise NotImplementedError(
        "call_model_api is a stub; implement the model api call and "
        "return the generated text.")


def _join_examples(examples):
    """Strip every example and terminate it with a space and a newline."""
    return "".join(example.strip() + " \n" for example in examples)


def read_prompts(prompt_path, prompt_type, n_example):
    """Read the prompt examples for knowledge or response generation.

    Args:
        prompt_path: Path of the prompt file. For ``"knowledge"`` every
            non-blank line is a JSON object whose first key maps to a list
            of example strings. For any other prompt type the file holds
            one example per line.
        prompt_type: ``"knowledge"`` selects the per-key dictionary format;
            any other value selects the response format.
        n_example: Number of leading lines kept for response prompts. It is
            not used for knowledge prompts.

    Returns:
        For ``"knowledge"``: a dict mapping each key to its prompt string.
        Otherwise: one prompt string. In both cases every example is
        stripped and followed by a space and a newline.
    """
    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(prompt_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank lines instead of failing in json.loads
                    continue
                line_dict = json.loads(line)
                if not line_dict:
                    continue
                key = next(iter(line_dict))
                # only the first occurrence of a key is kept
                if key not in prompt_examples_dict:
                    prompt_examples_dict[key] = _join_examples(line_dict[key])
        return prompt_examples_dict

    # prompts for the response generation
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt_examples = f.readlines()[:n_example]
    return _join_examples(prompt_examples)


def _build_knowledge_inputs(splits, prompt_dict):
    """Build the knowledge-generation input for one tab-split test sample."""
    topic = splits[0]
    last_turn = splits[1].split(" [SEP] ")[-1]
    # the prompt is looked up by "<topic> <last turn>"
    prompt = prompt_dict[topic + " " + last_turn]
    # inputs = prompt + current test
    return prompt + "( " + last_turn + " ) " + topic + " =>"


def _build_response_inputs(splits, prompt):
    """Build the response-generation input for one tab-split test sample."""
    topic = splits[0]
    last_turn = splits[1].split(" [SEP] ")[-1]
    knowledge = splits[2]
    last_turn = " ".join(word_tokenize(last_turn)).strip()
    knowledge = " ".join(word_tokenize(knowledge)).strip()
    # inputs = prompt + current test
    return (prompt
            + "Topic: " + topic + ". "
            + "User says: " + last_turn + " "
            + "We know that: " + knowledge + " "
            + "System replies:")


def generate_samples_by_calling_api():
    """Generate outputs by calling the model api.

    Reads the prompts from ``args.prompt_file``, builds one input per
    non-blank line of ``args.sample_input_file`` (tab-separated: topic,
    turns joined by " [SEP] ", and for response prompts the knowledge),
    sends it to ``call_model_api`` and writes each generation on its own
    line of ``args.sample_output_file``.

    Raises:
        AssertionError: If ``args.prompt_type`` is neither ``"knowledge"``
            nor ``"response"``.
        KeyError: For knowledge prompts, if no prompt exists for the
            sample's "<topic> <last turn>" key.
    """
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    prompts = read_prompts(
        args.prompt_file, args.prompt_type, args.num_prompt_examples)
    if args.prompt_type == "knowledge":
        build_inputs = _build_knowledge_inputs
    else:
        build_inputs = _build_response_inputs

    # NOTE(review): the output path is taken from args, mirroring
    # args.sample_input_file; no other binding of it is visible here.
    with open(args.sample_input_file, "r", encoding="utf-8") as fname, \
            open(args.sample_output_file, "w", encoding="utf-8") as fname_out:
        # call the api to get the output generations
        for test_sample in fname:
            test_sample = test_sample.strip()
            if not test_sample:
                # a blank line carries no topic/turns to build an input from
                continue
            splits = test_sample.split("\t")
            inputs = build_inputs(splits, prompts)

            # get the output generations from the api,
            # and write to the output file
            generations = call_model_api(inputs)
            fname_out.write(generations)
            fname_out.write("\n")


# ---------------------------------------------------------------------------
# Remaining hunks of the scraped diff. These definitions are only partially
# visible, so they are kept as comments exactly as shown and are not edited.
# ---------------------------------------------------------------------------
# def model_provider(pre_process=True, post_process=True):
#     """Build the model."""
#
# @@ -124,9 +239,7 @@ def generate_samples_by_prompting_input_from_file(model):
#     # construct inputs for knowledge generation
#     # then add the constructed inputs into the raw_text
#     turns = splits[1].split(" [SEP] ")
#     context = turns[-1]
#     raw_text += "( " + context + " ) " + topic + " =>"
#     raw_text += "( " + last_turn + " ) " + topic + " =>"
#     else:
#     # first add the prompt into the raw_text
#
# @@ -196,6 +309,11 @@ def generate_samples_by_prompting_input_from_file(model):
# def main():
#     args = get_args()
#     if args.api_prompting:
#         # obtain the generations by calling the api
#         generate_samples_by_calling_api()
#         return
#     if args.num_layers_per_virtual_pipeline_stage is not None:
#         print("Interleaved pipeline schedule is not yet supported for text generation.")
#         exit()
tasks/main.py +2 −0 Original line number Diff line number Diff line Loading @@ -102,6 +102,8 @@ def get_tasks_args(parser): help='datapath for golden sentences') group.add_argument('--out-seq-length', type=int, default=100, help='output sequence length') group.add_argument('--api-prompt', default=False, action="store_true", help='setup model api for prompting') return parser Loading