Commit c44f7622 authored by Mostofa Patwary

Many more features added

parent 6013e23c
+184 −81
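
The diff below reworks the task-deduplication script from a single __main__ block into helper functions (compute_tasks_ngrams, compute_ngram_freq_sorted, get_ngrams_above_threshold, clean_ngrams_above_threshold) and turns deduplication into a two-pass workflow. A first pass over the dedup dataset, gated by args.get_ngram_freq_only and the new local_ngram accumulator, counts for each task n-gram how many corpus documents contain it; only n-grams whose count exceeds --key-threshold are kept, and that dictionary can be pickled with --save-dictionary and reloaded with --load-dictionary. A second pass then trims or drops documents containing those frequent n-grams. The --ngram-size flag is renamed --max-ngram-size and a --min-ngram-size flag is added. A minimal sketch of the thresholding idea, with made-up helper names rather than the commit's own code:

from collections import Counter

def ngrams_of(words, n):
    # every n-word window, keyed by " ".join() like the script's ngrams dict
    return (" ".join(words[i:i + n]) for i in range(len(words) - n + 1))

def count_and_threshold(task_ngrams, corpus_docs, n, key_threshold):
    # count, for each task ngram, how many documents contain it, then keep
    # only the ngrams whose document count exceeds key_threshold
    counts = Counter()
    for doc in corpus_docs:
        words = doc.split()
        counts.update({seq for seq in ngrams_of(words, n) if seq in task_ngrams})
    return {seq for seq, c in counts.items() if c > key_threshold}

# toy example: "a b c" occurs in two documents, "p q r" in one; with
# key_threshold=1 only "a b c" survives for the cleaning pass
docs = ["a b c d", "x a b c y", "p q r"]
print(count_and_threshold({"a b c", "p q r"}, docs, 3, 1))  # -> {'a b c'}
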
@@ -24,6 +24,7 @@ from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
@@ -61,11 +62,23 @@ def split_text(text, start_position, remove_char_each_side, seq):
    return text_first, text_second

def check_and_clean_text(args, words, ngrams, text, start_position, \
    text_buf_ngram_free, text_buf):
    text_buf_ngram_free, text_buf, local_ngram):

    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)
        #print(" [matched]: {}".format(seq), flush=True)

        if args.get_ngram_freq_only:
            # increase freq of this seq and then only consider the later part
            # of the text for further processing
            if seq in local_ngram:
                local_ngram[seq] += 1
            else:
                local_ngram[seq] = 1
            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
            if (start_position + len(seq) + 1) < len(text):
                text_buf.append(text[start_position + len(seq) + 1:len(text)])
            return False            

        # split the text
        text_first, text_second = split_text(text, start_position, \
@@ -84,6 +97,7 @@ def check_and_clean_text(args, words, ngrams, text, start_position, \
    # ngram free
    return True


def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    # remove all the ngrams

@@ -95,6 +109,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
        text_buf = []

    text_buf_ngram_free = []
    local_ngram = {}
    while len(text_buf) > 0:

        # get the first one from the buffer
@@ -103,10 +118,10 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
        
        ngram_free = True
        # find each max ngram and check the dictionary
        for i in range(len(words) - args.ngram_size + 1):
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(args, words[i:\
                i+args.ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf)
                i+args.max_ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf, local_ngram)

            # is the seq ngram free? if not, stop and reprocess the split text
            if not check_ngram_free:
@@ -118,7 +133,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args, words[i:\
                    i+ngram_len], ngrams, text, positions[i], \
                    text_buf_ngram_free, text_buf)
                    text_buf_ngram_free, text_buf, local_ngram)

                # same check as above
                if not check_ngram_free:
@@ -130,16 +145,16 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.ngram_size > 0:
        if ngram_free and len(words) - args.max_ngram_size > 0:
            # get the last words of the last max ngram
            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
            last_seq_start_position = len(words) - args.ngram_size
            last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as it has been considered already
                if ngram_len == args.ngram_size:
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in max n-grams and check
@@ -147,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                    check_ngram_free = check_and_clean_text(args, \
                        last_seq_words[i:i+ngram_len], ngrams, text,\
                        positions[last_seq_start_position+i], \
                        text_buf_ngram_free, text_buf)
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False
@@ -157,34 +172,35 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                    break

        # texts are ngram free
        if ngram_free:
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
        len(myjson[key]):
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
        len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed
    return text_buf_ngram_free, trimmed, local_ngram

# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = pos
        ngrams[seq] = 0
        #ngrams[seq] = pos

# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    words, positions = get_words(text)
    if len(words) == 0:
    if len(words) < args.min_ngram_size:
        return

    if len(words) < args.ngram_size:
    if len(words) < args.max_ngram_size:
        insert_dict(words, ngrams, positions[0])

    for i in range(len(words) - args.ngram_size+1):
        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
    for i in range(len(words) - args.max_ngram_size+1):
        insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])


# Build ngrams for the lambada dataset
@@ -203,6 +219,7 @@ def process_task_lambda(args, task_file, ngrams):

# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):

    print(' reading from {} and computing ngrams'.format(task_name))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
@@ -253,39 +270,7 @@ def process_task(args, task_name, ngrams):
    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
            len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

if __name__ == '__main__':

    # we use 13-grams; any text shorter than 200 characters is removed,
    # as is any document split into more than 10 pieces

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' support [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                       help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                       help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file name to save dedup dataset')
    # Default dedup values
    parser.add_argument('--ngram-size', type=int, default=13,
                       help='Maximum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                       help='Remove any text below this length.')
    parser.add_argument('--splits-count', type=int, default=10,
                       help='Remove any documents more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                       help='Maximum size of ngram to use.')

    args = parser.parse_args()

    # Build ngrams
    ngrams = {}
def compute_tasks_ngrams(args, ngrams):
    start_time = time.time()
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)
@@ -294,10 +279,10 @@ if __name__ == '__main__':
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
        start_time), flush=True)

    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)

    # get the range of the size of the ngrams
def compute_ngram_freq_sorted(args, ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
@@ -309,33 +294,74 @@ if __name__ == '__main__':
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
            len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
            ngrams_freq_sorted) -1 ][0]), flush=True)
    return ngrams_freq_sorted

    id_prefix = '-'.join(args.tasks[::2])
def get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
    dedup_file, dedup_key, ngrams_freq_sorted):

    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
    start_time = time.time()
    # get the ngrams frequency
    args.get_ngram_freq_only = True
 
    # Open the large file to process in parallel
    num_workers = 40
    pool = multiprocessing.Pool(num_workers)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
 
    counter = 0
    for _, _, local_ngram in free_ngrams_abt:
        counter += 1
        if counter % 1000 == 0:
            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                    format(counter, time.time() - start_time), flush=True)
        for local_key in local_ngram:
            if local_key in ngrams:
                ngrams[local_key] += 1
        local_ngram = {}

    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
        start_time), flush=True)
    pool.close()
    pool.join()

    start_time = time.time()
    counter_threshold = 0
    # Get ngrams above threshold
    for local_key, local_val in ngrams.items():
        if ngrams[local_key] > args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_above_threshold[local_key] = 1
            
    print(' Ngrams above threshold {}'.format(counter_threshold), flush=True)
    fin.close()

    if args.output is not None:
        out_f = open(args.output, 'wb')
def clean_ngrams_above_threshold(args, ngrams_above_threshold, dedup_file, \
    dedup_key):

    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
    start_time = time.time()
    # Now actually filter the dataset
    args.get_ngram_freq_only = False
    id_prefix = '-'.join(args.tasks[::2])

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]
    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_above_threshold)

    # Setup multi-processing.
    # Open the large file to process in parallel
    num_workers = 40
    fin = open(dedup_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    free_ngram_x=partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
        ngrams_freq_sorted=ngrams_freq_sorted)
    fin = open(dedup_file, 'r', encoding='utf-8')
    free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
        ngrams=ngrams_above_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
 
    free_ngrams = pool.imap(free_ngram_x, fin, 25)
    out_f = open(args.output, 'wb')

    for text_buf_ngram_free, trimmed in free_ngrams:
    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    for text_buf_ngram_free, trimmed, _ in free_ngrams_clean:
        counter += 1
        try:

@@ -361,18 +387,95 @@ if __name__ == '__main__':
                    out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [search]> processed {} documents in {:.2f} seconds ...'.
                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                    format(counter, time.time() - start_time), flush=True)
        except Exception as e:
            print('Error:', e)

    if args.output is not None:
        out_f.close()
    print(' [final]> processed {} documents in {:.2f} seconds ...'.
        format(counter, time.time() - start_time), flush=True)
    
    print(' Total docs {} splitted {} ignored {} splits > threshold {} trimmed'\
        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
        , flush=True)

    pool.close()
    pool.join()

    out_f.close()
    fin.close()

    print("Deduped file written to: {}".format(args.output), flush=True)
    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
        trimmed_count), flush=True)
if __name__ == '__main__':

    # we use 13-grams; any text shorter than 200 characters is removed,
    # as is any document split into more than 10 pieces

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' support [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                       help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                       help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file name to save dedup dataset')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                       help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                       help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                       help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                       help='Frequency threshold: only ngrams whose count in'
                        ' the dedup dataset exceeds this value are used for'
                        ' filtering')
    parser.add_argument('--save-dictionary', type=str, default=None,
                       help='Path to save the ngram dictionary (pickle)')
    parser.add_argument('--load-dictionary', type=str, default=None,
                       help='Path to load a previously saved ngram dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                       help='Remove any document split into more than this'
                        ' many pieces')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                       help='Number of characters to remove on each side of'
                        ' a matched ngram')

    args = parser.parse_args()

    assert len(args.dedup_dataset) == 2
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = 40
    if args.load_dictionary is None:

        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get ngram freq from large file in parallel
        # get ngrams above threshold
        ngrams_above_threshold = {}
        get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
            dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_above_threshold, save_dict_handle)
    else:
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_above_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_above_threshold(args, ngrams_above_threshold, \
            dedup_file, dedup_key)

    print('done :-)')
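
For context on the --remove-char-each-side flag defined above: whenever a flagged n-gram is found in a document, split_text()/check_and_clean_text() cut the document around the match and drop that many characters on each side before re-checking the remaining pieces. A rough sketch of that splitting idea, using a hypothetical helper rather than the commit's exact code:

def split_on_match(text, start_position, match_len, remove_char_each_side):
    # drop the matched ngram plus remove_char_each_side characters on both
    # sides, returning the two remaining pieces for further ngram checks
    first_end = max(0, start_position - remove_char_each_side)
    second_start = min(len(text),
                       start_position + match_len + remove_char_each_side)
    return text[:first_end], text[second_start:]

doc = "prefix text AAA BBB CCC suffix text"
left, right = split_on_match(doc, doc.index("AAA"), len("AAA BBB CCC"), 3)
print(repr(left), repr(right))  # -> 'prefix te' 'ffix text'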