Commit fc06e7a2 authored by zihanl's avatar zihanl
Browse files

update api prompt

parent 57c8996e
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@

import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
@@ -104,12 +105,12 @@ def generate_samples_by_calling_api():
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(sample_output_file, "w")
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = input_str.split("\t")
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api