Loading tools/openwebtext/filter_ngrams.py +70 −24 Original line number Diff line number Diff line Loading @@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted): # check if the text has only been trimmed trimmed = 0 if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \ if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \ len(myjson[key]): trimmed = 1 Loading Loading @@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams): print(" Entities in ngrams {}".format(len(ngrams)), flush=True) # Build ngrams for the squad v2 dataset def process_task_squad(args, ngrams): # Build ngrams for the dataset of the given task def process_task(args, task_name, ngrams): print(' reading from {} and computing ngrams'.format('import datasets')) # using squad data from datasets print(" Current entities in ngrams {}".format(len(ngrams)), flush=True) # using validation/test data from datasets from datasets import load_dataset squad_v2 = load_dataset('squad_v2', split='validation') for line in squad_v2: entities_in_ngrams = len(ngrams) # load the dataset if task_name == 'squad': dataset = load_dataset('squad_v2', split='validation') elif task_name == 'natural_questions': dataset = load_dataset('natural_questions', split='validation') elif task_name == 'triviaqa': dataset = load_dataset('trivia_qa', 'unfiltered', split='test') elif task_name == 'webqa': dataset = load_dataset('web_questions', split='test') elif task_name == 'race': dataset = load_dataset('race', 'all', split='test') elif task_name == 'drop': dataset = load_dataset('drop', split='validation') elif task_name == 'coqa': dataset = load_dataset('coqa', split='validation') elif task_name == 'piqa': dataset = load_dataset('piqa', split='test') else: print("Invalid task name: {}".format(task_name), flush=True) return # read the dataset and add to ngrams for line in dataset: try: if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']: text = line['question'] compute_ngrams_insert_dict(args, text, ngrams) elif task_name == 'natural_questions': text = line['question']['text'] compute_ngrams_insert_dict(args, text, ngrams) elif task_name == 'coqa': all_questions = line['questions'] for question in all_questions: compute_ngrams_insert_dict(args, question, ngrams) elif task_name == 'piqa': text = line['goal'] compute_ngrams_insert_dict(args, text, ngrams) except Exception as e: print('Error:', e) print(" Entities in ngrams {}".format(len(ngrams)), flush=True) print(" After task {} entities in ngrams {}, added {}".format(task_name, \ len(ngrams), len(ngrams) - entities_in_ngrams), flush=True) if __name__ == '__main__': Loading @@ -227,7 +263,8 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--tasks', nargs = '*', required=True, default=None, \ help = 'Tasks to use for deduplication: currently ' ' suuport [lambada, squad]') ' suuport [lambada, squad, natural_questions,' ' triviaqa, webqa, race, drop, coqa, and piqa]') parser.add_argument('--lambada-path', type=str, default=None, help='Only Lambada task needs the path') parser.add_argument('--dedup-dataset', nargs = '*', default=None, Loading @@ -249,13 +286,16 @@ if __name__ == '__main__': # Build ngrams ngrams = {} start_time = time.time() for _, task_name in enumerate(args.tasks): print('Task: {}'.format(task_name), flush=True) if task_name == 'lambada': assert args.lambada_path is not None process_task_lambda(args, args.lambada_path, ngrams) if task_name == 'squad': process_task_squad(args, ngrams) else: process_task(args, task_name, ngrams) print(" Taken time {:.2f}".format(time.time() - start_time), flush=True) # get the range of the size of the ngrams ngrams_freq = {} Loading @@ -263,8 +303,8 @@ if __name__ == '__main__': length = len(ngram_key.split()) ngrams_freq[length] = ngrams_freq[length] + 1 if length in \ ngrams_freq else 1 ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1]) ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0]) print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True) print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\ len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\ Loading @@ -276,7 +316,10 @@ if __name__ == '__main__': counter = 0 start_time = time.time() if args.output is not None: out_f = open(args.output, 'wb') splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0 assert len(args.dedup_dataset) == 2 Loading @@ -299,7 +342,7 @@ if __name__ == '__main__': trimmed_count += trimmed if len(text_buf_ngram_free) > 1: splitted += (len(text_buf_ngram_free) - 1) splitted += 1 if len(text_buf_ngram_free) == 0: ignored += 1 # more than 10 splits ignored Loading @@ -307,9 +350,10 @@ if __name__ == '__main__': text_buf_ngram_free = [] split_mt_thld += 1 if args.output is not None: for i in range(len(text_buf_ngram_free)): split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \ + '-{:010d}'.format(int(i)) split_id_string = id_prefix + '-{:010d}'.format(int(\ counter)) + '-{:010d}'.format(int(i)) outjson = json.dumps({"text":text_buf_ngram_free[i], id_prefix+"_split_id":split_id_string}, ensure_ascii=False) Loading @@ -322,7 +366,9 @@ if __name__ == '__main__': except Exception as e: print('Error:', e) if args.output is not None: out_f.close() fin.close() print("Deduped file written to: {}".format(args.output), flush=True) Loading Loading
tools/openwebtext/filter_ngrams.py +70 −24 Original line number Diff line number Diff line Loading @@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted): # check if the text has only been trimmed trimmed = 0 if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \ if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \ len(myjson[key]): trimmed = 1 Loading Loading @@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams): print(" Entities in ngrams {}".format(len(ngrams)), flush=True) # Build ngrams for the squad v2 dataset def process_task_squad(args, ngrams): # Build ngrams for the dataset of the given task def process_task(args, task_name, ngrams): print(' reading from {} and computing ngrams'.format('import datasets')) # using squad data from datasets print(" Current entities in ngrams {}".format(len(ngrams)), flush=True) # using validation/test data from datasets from datasets import load_dataset squad_v2 = load_dataset('squad_v2', split='validation') for line in squad_v2: entities_in_ngrams = len(ngrams) # load the dataset if task_name == 'squad': dataset = load_dataset('squad_v2', split='validation') elif task_name == 'natural_questions': dataset = load_dataset('natural_questions', split='validation') elif task_name == 'triviaqa': dataset = load_dataset('trivia_qa', 'unfiltered', split='test') elif task_name == 'webqa': dataset = load_dataset('web_questions', split='test') elif task_name == 'race': dataset = load_dataset('race', 'all', split='test') elif task_name == 'drop': dataset = load_dataset('drop', split='validation') elif task_name == 'coqa': dataset = load_dataset('coqa', split='validation') elif task_name == 'piqa': dataset = load_dataset('piqa', split='test') else: print("Invalid task name: {}".format(task_name), flush=True) return # read the dataset and add to ngrams for line in dataset: try: if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']: text = line['question'] compute_ngrams_insert_dict(args, text, ngrams) elif task_name == 'natural_questions': text = line['question']['text'] compute_ngrams_insert_dict(args, text, ngrams) elif task_name == 'coqa': all_questions = line['questions'] for question in all_questions: compute_ngrams_insert_dict(args, question, ngrams) elif task_name == 'piqa': text = line['goal'] compute_ngrams_insert_dict(args, text, ngrams) except Exception as e: print('Error:', e) print(" Entities in ngrams {}".format(len(ngrams)), flush=True) print(" After task {} entities in ngrams {}, added {}".format(task_name, \ len(ngrams), len(ngrams) - entities_in_ngrams), flush=True) if __name__ == '__main__': Loading @@ -227,7 +263,8 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--tasks', nargs = '*', required=True, default=None, \ help = 'Tasks to use for deduplication: currently ' ' suuport [lambada, squad]') ' suuport [lambada, squad, natural_questions,' ' triviaqa, webqa, race, drop, coqa, and piqa]') parser.add_argument('--lambada-path', type=str, default=None, help='Only Lambada task needs the path') parser.add_argument('--dedup-dataset', nargs = '*', default=None, Loading @@ -249,13 +286,16 @@ if __name__ == '__main__': # Build ngrams ngrams = {} start_time = time.time() for _, task_name in enumerate(args.tasks): print('Task: {}'.format(task_name), flush=True) if task_name == 'lambada': assert args.lambada_path is not None process_task_lambda(args, args.lambada_path, ngrams) if task_name == 'squad': process_task_squad(args, ngrams) else: process_task(args, task_name, ngrams) print(" Taken time {:.2f}".format(time.time() - start_time), flush=True) # get the range of the size of the ngrams ngrams_freq = {} Loading @@ -263,8 +303,8 @@ if __name__ == '__main__': length = len(ngram_key.split()) ngrams_freq[length] = ngrams_freq[length] + 1 if length in \ ngrams_freq else 1 ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1]) ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0]) print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True) print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\ len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\ Loading @@ -276,7 +316,10 @@ if __name__ == '__main__': counter = 0 start_time = time.time() if args.output is not None: out_f = open(args.output, 'wb') splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0 assert len(args.dedup_dataset) == 2 Loading @@ -299,7 +342,7 @@ if __name__ == '__main__': trimmed_count += trimmed if len(text_buf_ngram_free) > 1: splitted += (len(text_buf_ngram_free) - 1) splitted += 1 if len(text_buf_ngram_free) == 0: ignored += 1 # more than 10 splits ignored Loading @@ -307,9 +350,10 @@ if __name__ == '__main__': text_buf_ngram_free = [] split_mt_thld += 1 if args.output is not None: for i in range(len(text_buf_ngram_free)): split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \ + '-{:010d}'.format(int(i)) split_id_string = id_prefix + '-{:010d}'.format(int(\ counter)) + '-{:010d}'.format(int(i)) outjson = json.dumps({"text":text_buf_ngram_free[i], id_prefix+"_split_id":split_id_string}, ensure_ascii=False) Loading @@ -322,7 +366,9 @@ if __name__ == '__main__': except Exception as e: print('Error:', e) if args.output is not None: out_f.close() fin.close() print("Deduped file written to: {}".format(args.output), flush=True) Loading