Commit 6013e23c authored by Mostofa Patwary

Dedup for other tasks added

parent b08b5edc
+70 −24
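
The diff below adds deduplication n-gram sources for several more downstream tasks (natural_questions, triviaqa, webqa, race, drop, coqa, piqa) alongside the existing lambada and squad handling. As a rough usage sketch based only on the argparse options visible in this diff — the script's file name is not shown on this page (filter_ngrams.py is an assumption), and --dedup-dataset is assumed to take an input json path followed by its text key:

    python filter_ngrams.py \
        --tasks squad triviaqa webqa race drop coqa piqa \
        --dedup-dataset <corpus.json> text \
        --output <corpus_deduped.json>

--lambada-path is only required when lambada is among the requested tasks.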
@@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):

    # check if the text has only been trimmed
    trimmed = 0
    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
        len(myjson[key]):
        trimmed = 1

@@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams):
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)


# Build ngrams for the squad v2 dataset
def process_task_squad(args, ngrams):
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
    print(' reading from {} and computing ngrams'.format('import datasets'))
    # using squad data from datasets
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
    from datasets import load_dataset
    squad_v2 = load_dataset('squad_v2', split='validation')

    for line in squad_v2:
    entities_in_ngrams = len(ngrams)

    # load the dataset
    if task_name == 'squad':
        dataset = load_dataset('squad_v2', split='validation')
    elif task_name == 'natural_questions':
        dataset = load_dataset('natural_questions', split='validation')
    elif task_name == 'triviaqa':
        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
    elif task_name == 'webqa':
        dataset = load_dataset('web_questions', split='test')
    elif task_name == 'race':
        dataset = load_dataset('race', 'all', split='test')
    elif task_name == 'drop':
        dataset = load_dataset('drop', split='validation')
    elif task_name == 'coqa':
        dataset = load_dataset('coqa', split='validation')
    elif task_name == 'piqa':
        dataset = load_dataset('piqa', split='test')
    else:
        print("Invalid task name: {}".format(task_name), flush=True)
        return

    # read the dataset and add to ngrams
    for line in dataset:
        try:
            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
                text = line['question']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'natural_questions':
                text = line['question']['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            elif task_name == 'coqa':
                all_questions = line['questions']
                for question in all_questions:
                    compute_ngrams_insert_dict(args, question, ngrams)
            elif task_name == 'piqa':
                text = line['goal']
                compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
            len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

if __name__ == '__main__':

@@ -227,7 +263,8 @@ if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' support [lambada, squad]')
                        ' support [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                       help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
@@ -249,13 +286,16 @@ if __name__ == '__main__':

    # Build ngrams
    ngrams = {}
    start_time = time.time()
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            assert args.lambada_path is not None
            process_task_lambda(args, args.lambada_path, ngrams)
        if task_name == 'squad':
            process_task_squad(args, ngrams)
        else:
            process_task(args, task_name, ngrams)

    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)

    # get the range of the size of the ngrams
    ngrams_freq = {}
@@ -263,8 +303,8 @@ if __name__ == '__main__':
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
            ngrams_freq else 1
    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])

    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
            len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
@@ -276,7 +316,10 @@ if __name__ == '__main__':

    counter = 0
    start_time = time.time()

    if args.output is not None:
        out_f = open(args.output, 'wb')

    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0

    assert len(args.dedup_dataset) == 2
@@ -299,7 +342,7 @@ if __name__ == '__main__':
            trimmed_count += trimmed

            if len(text_buf_ngram_free) > 1:
                splitted += (len(text_buf_ngram_free) - 1)
                splitted += 1
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored
@@ -307,9 +350,10 @@ if __name__ == '__main__':
                text_buf_ngram_free = []
                split_mt_thld += 1

            if args.output is not None:
                for i in range(len(text_buf_ngram_free)):
                split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
                    + '-{:010d}'.format(int(i))
                    split_id_string = id_prefix + '-{:010d}'.format(int(\
                        counter)) + '-{:010d}'.format(int(i))
                    outjson = json.dumps({"text":text_buf_ngram_free[i],
                        id_prefix+"_split_id":split_id_string},
                        ensure_ascii=False)
@@ -322,7 +366,9 @@ if __name__ == '__main__':
        except Exception as e:
            print('Error:', e)

    if args.output is not None:
        out_f.close()

    fin.close()

    print("Deduped file written to: {}".format(args.output), flush=True)