tools/openwebtext/filter_ngrams.py  +184 −81

@@ -24,6 +24,7 @@ from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys

@@ -61,11 +62,23 @@ def split_text(text, start_position, remove_char_each_side, seq):
    return text_first, text_second

def check_and_clean_text(args, words, ngrams, text, start_position, \
    text_buf_ngram_free, text_buf):
    text_buf_ngram_free, text_buf, local_ngram):
    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)
        #print(" [matched]: {}".format(seq), flush=True)

        if args.get_ngram_freq_only:
            # increase freq of this seq and then only consider the later part
            # of the text for further processing
            if seq in local_ngram:
                local_ngram[seq] += 1
            else:
                local_ngram[seq] = 1
            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
            if (start_position + len(seq) + 1) < len(text):
                text_buf.append(text[start_position + len(seq) + 1:len(text)])
            return False

        # split the text
        text_first, text_second = split_text(text, start_position, \

@@ -84,6 +97,7 @@ def check_and_clean_text(args, words, ngrams, text, start_position, \
    # ngram free
    return True

def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    # remove all the ngrams

@@ -95,6 +109,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    text_buf = []
    text_buf_ngram_free = []
    local_ngram = {}

    while len(text_buf) > 0:
        # get the first one from the buffer

@@ -103,10 +118,10 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
        ngram_free = True

        # find each max n-grams and check dictionary
        for i in range(len(words) - args.ngram_size + 1):
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(args, words[i:\
                i+args.ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf)
                i+args.max_ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf, local_ngram)

            # the seq is ngram free? if yes, break
            if not check_ngram_free:
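The hunks above thread a new local_ngram dict through check_and_clean_text and free_ngram: when args.get_ngram_freq_only is set, a matched sequence only increments its per-document count and scanning continues on the remainder of the text, instead of splitting the document. A minimal, self-contained sketch of that counting mode (helper name, 3-gram size, and whitespace tokenization are illustrative, not from the file):

# Illustrative sketch, not part of filter_ngrams.py: count matches of known
# n-grams (3-grams here for brevity) in one document without splitting it,
# mirroring the get_ngram_freq_only path of check_and_clean_text.
def count_matches(text, known_ngrams, n=3):
    words = text.split()
    local_ngram = {}
    i = 0
    while i + n <= len(words):
        seq = " ".join(words[i:i + n])
        if seq in known_ngrams:
            local_ngram[seq] = local_ngram.get(seq, 0) + 1
            i += n      # skip past the match and keep scanning the rest
        else:
            i += 1
    return local_ngram

known = {"the quick brown", "lazy dog again"}
doc = "the quick brown fox jumps over the lazy dog again and again"
print(count_matches(doc, known))  # {'the quick brown': 1, 'lazy dog again': 1}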
@@ -118,7 +133,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args, words[i:\
                    i+ngram_len], ngrams, text, positions[i], \
                    text_buf_ngram_free, text_buf)
                    text_buf_ngram_free, text_buf, local_ngram)

                # same check as above
                if not check_ngram_free:

@@ -130,16 +145,16 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                    break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.ngram_size > 0:
        if ngram_free and len(words) - args.max_ngram_size > 0:

            # get the last words of the lax max ngram
            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
            last_seq_start_position = len(words) - args.ngram_size
            last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as has been considered already
                if ngram_len == args.ngram_size:
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in max n-grams and check

@@ -147,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                    check_ngram_free = check_and_clean_text(args, \
                        last_seq_words[i:i+ngram_len], ngrams, text,\
                        positions[last_seq_start_position+i], \
                        text_buf_ngram_free, text_buf)
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False

@@ -157,34 +172,35 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                        break

        # texts are ngram free
        if ngram_free:
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
        len(myjson[key]):
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
        len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed
    return text_buf_ngram_free, trimmed, local_ngram

# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = pos
        ngrams[seq] = 0
        #ngrams[seq] = pos

# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    words, positions = get_words(text)
    if len(words) == 0:
    if len(words) < args.min_ngram_size:
        return

    if len(words) < args.ngram_size:
    if len(words) < args.max_ngram_size:
        insert_dict(words, ngrams, positions[0])

    for i in range(len(words) - args.ngram_size+1):
        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
    for i in range(len(words) - args.max_ngram_size+1):
        insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])

# Build ngrams for the lambada dataset

@@ -203,6 +219,7 @@ def process_task_lambda(args, task_file, ngrams):
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
    print(' reading from {} and computing ngrams'.format('import datasets'))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)

    # using validation/test data from datasets
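Above, insert_dict now seeds every key with a zero count (rather than a position), --ngram-size becomes --max-ngram-size, and compute_ngrams_insert_dict skips texts shorter than the new --min-ngram-size. A rough standalone equivalent of that dictionary build (whitespace tokenization instead of the file's get_words; sizes passed explicitly):

# Illustrative sketch, not part of filter_ngrams.py: build a dictionary of
# max-size n-grams, each initialised to a zero count, skipping texts that
# are shorter than min_ngram_size (whitespace tokenization for simplicity).
def build_ngram_dict(text, ngrams, min_ngram_size=8, max_ngram_size=13):
    words = text.split()
    if len(words) < min_ngram_size:
        return
    if len(words) < max_ngram_size:
        ngrams.setdefault(" ".join(words), 0)
    for i in range(len(words) - max_ngram_size + 1):
        ngrams.setdefault(" ".join(words[i:i + max_ngram_size]), 0)

ngrams = {}
build_ngram_dict("one two three four five six seven eight nine ten", ngrams)
print(len(ngrams))  # 1: the whole 10-word text (>= 8 words but < 13)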
@@ -253,39 +270,7 @@ def process_task(args, task_name, ngrams):
    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

if __name__ == '__main__':

    # we use 13-grams, any text less than 200 characters got removed
    # any text splitted more than 10 got removed as well

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' suuport [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                        help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save dedup dataset')
    # Default dedup values
    parser.add_argument('--ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any documents more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Maximum size of ngram to use.')
    args = parser.parse_args()

    # Build ngrams
    ngrams = {}

def compute_tasks_ngrams(args, ngrams):
    start_time = time.time()
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)

@@ -294,10 +279,10 @@ if __name__ == '__main__':
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
        start_time), flush=True)
    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)

# get the range of the size of the ngrams
def compute_ngram_freq_sorted(args, ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
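The old __main__ body above is being split into reusable functions: compute_tasks_ngrams collects n-grams from the benchmark tasks, and compute_ngram_freq_sorted groups the collected keys by their length in words so free_ngram can try the longest patterns first. A compact sketch of that grouping step (my own minimal version; the real function also prints the min and max sizes it found):

# Illustrative sketch, not part of filter_ngrams.py: histogram of stored
# n-gram keys by their length in words, sorted by length, mirroring what
# compute_ngram_freq_sorted returns.
def ngram_length_histogram(ngrams):
    freq = {}
    for key in ngrams:
        length = len(key.split())
        freq[length] = freq.get(length, 0) + 1
    return sorted(freq.items())

ngrams = {"a b c": 0, "d e f": 0, "g h": 0}
print(ngram_length_histogram(ngrams))  # [(2, 1), (3, 2)]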
@@ -309,33 +294,74 @@
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
        len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
        ngrams_freq_sorted) -1 ][0]), flush=True)
    return ngrams_freq_sorted

    id_prefix = '-'.join(args.tasks[::2])

def get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
    dedup_file, dedup_key, ngrams_freq_sorted):

    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
    start_time = time.time()

    # get the ngrams frequency
    args.get_ngram_freq_only = True

    # Open the large file to process in parallel
    num_workers = 40
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)

    counter = 0
    for _, _, local_ngram in free_ngrams_abt:
        counter += 1
        if counter % 1000 == 0:
            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                format(counter, time.time() - start_time), flush=True)
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}

    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
        start_time), flush=True)
    pool.close()
    pool.join()

    start_time = time.time()
    counter_threshold = 0
    # Get ngram above theadhold
    for local_key, local_val in ngrams.items():
        if ngrams[local_key] > args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_above_threshold[local_key] = 1

    print(' Ngrams above threshold {}'.format(counter_threshold), flush=True)
    fin.close()

    if args.output is not None:
        out_f = open(args.output, 'wb')

def clean_ngrams_above_threshold(args, ngrams_above_threshold, dedup_file, \
    dedup_key):

    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
    start_time = time.time()

    # Now actually filter the dataset
    args.get_ngram_freq_only = False
    id_prefix = '-'.join(args.tasks[::2])

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_above_threshold)

    # Setup multi-processing.
    # Open the large file to process in parallel
    num_workers = 40
    fin = open(dedup_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    free_ngram_x=partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
        ngrams_freq_sorted=ngrams_freq_sorted)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams_above_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)

    free_ngrams = pool.imap(free_ngram_x, fin, 25)
    out_f = open(args.output, 'wb')

    for text_buf_ngram_free, trimmed in free_ngrams:
    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    for text_buf_ngram_free, trimmed, _ in free_ngrams_clean:
        counter += 1
        try:
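The two functions introduced in this hunk give the script a two-pass structure: get_ngrams_above_threshold scans the corpus once in counting mode, merges each document's local_ngram results into the global dictionary (one increment per document that contains the key), and keeps only keys seen more than --key-threshold times; clean_ngrams_above_threshold then re-scans and filters using only that smaller dictionary. A schematic, sequential version of the count, merge, and threshold step (the script does this with multiprocessing.Pool.imap over the JSON lines; names and the tiny threshold below are illustrative):

# Illustrative sketch, not part of filter_ngrams.py: merge per-document
# match counts into the global dictionary (one increment per document that
# contains the key) and keep only keys above the threshold.
def merge_and_threshold(per_doc_counts, ngrams, key_threshold):
    for local_ngram in per_doc_counts:        # one dict per document
        for key in local_ngram:
            if key in ngrams:
                ngrams[key] += 1
    return {key: 1 for key, total in ngrams.items() if total > key_threshold}

ngrams = {"alpha beta gamma": 0, "delta epsilon zeta": 0}
docs = [{"alpha beta gamma": 7},
        {"alpha beta gamma": 5, "delta epsilon zeta": 2}]
# key_threshold defaults to 10 in the script; 1 keeps this toy example small.
print(merge_and_threshold(docs, ngrams, key_threshold=1))
# {'alpha beta gamma': 1}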
@@ -361,18 +387,95 @@ if __name__ == '__main__':
            out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [search]> processed {} documents in {:.2f} seconds ...'.
                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                    format(counter, time.time() - start_time), flush=True)

        except Exception as e:
            print('Error:', e)

    if args.output is not None:
        out_f.close()

    print(' [final]> processed {} documents in {:.2f} seconds ...'.
        format(counter, time.time() - start_time), flush=True)
    print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
        , flush=True)

    pool.close()
    pool.join()

    out_f.close()
    fin.close()
    print("Deduped file written to: {}".format(args.output), flush=True)
    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
        trimmed_count), flush=True)

if __name__ == '__main__':

    # we use 13-grams, any text less than 200 characters got removed
    # any text splitted more than 10 got removed as well

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' suuport [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                        help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save dedup dataset')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Number of keys to consider as threshold')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the dictionary')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any documents more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Maximum size of ngram to use.')
    args = parser.parse_args()

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = 40

    if args.load_dictionary is None:

        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get ngram freq from large file in parallel
        # get ngrams above threshold
        ngrams_above_threshold = {}
        get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
            dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_above_threshold, save_dict_handle)
    else:
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_above_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_above_threshold(args, ngrams_above_threshold, \
            dedup_file, dedup_key)

    print('done :-)')
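The new --save-dictionary and --load-dictionary options mean the expensive counting pass can be run once and reused: the thresholded dictionary is pickled after the first pass, and a later run can load it and go straight to filtering. A minimal sketch of that round trip (the file name is illustrative):

# Illustrative sketch, not part of filter_ngrams.py: persist the thresholded
# dictionary on one run and reuse it on a later run, as the new
# --save-dictionary / --load-dictionary flags allow.
import pickle

ngrams_above_threshold = {"some frequent benchmark n-gram": 1}

# First run: after get_ngrams_above_threshold, save the dictionary.
with open('ngrams_above_threshold.pkl', 'wb') as save_dict_handle:
    pickle.dump(ngrams_above_threshold, save_dict_handle)

# Later run: skip the counting pass and load it before filtering.
with open('ngrams_above_threshold.pkl', 'rb') as load_dict_handle:
    ngrams_above_threshold = pickle.load(load_dict_handle)

print(len(ngrams_above_threshold))  # 1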