Commit 7a5768ac authored by Jared Casper

Merge branch 'main_dedup' into 'main'

Training data and task deduplication

See merge request ADLR/megatron-lm!252
parents f32a638d 0fa728ac
+8 −4
@@ -26,7 +26,8 @@ python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications.
Additional cleanup (e.g. removing documents with fewer than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications (see the example invocation below), and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
```
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
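For recurrent deduplications, the fingerprints can be saved on one run and reused on later runs instead of being recomputed. A hedged example follows: `--save-fingerprints` appears in `find_duplicates.py` below, while the exact name of the load flag (written here as `--load-fingerprints`) should be confirmed with `python find_duplicates.py --help`. File names are illustrative.
```
# first run: compute fingerprints, save them, and emit possible duplicate ids
python find_duplicates.py --inputs cc.json cc_id news.json news_id --save-fingerprints fingerprints.pkl --output url_pairs.json

# later run: reuse the saved fingerprints instead of recomputing them
python find_duplicates.py --load-fingerprints fingerprints.pkl --inputs new_data.json new_id --output url_pairs_new.json
```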
@@ -46,10 +47,13 @@ shuf <cleaned deduped data file> -o train_data.json

# Deduplicating ngrams

To deduplicate the downstream tasks from the training dataset, we run the following command.
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.

```
python filter_ngrams.py <down stream task dataset> <training dataset to deduplicate> <output training dataset>
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with fewer than 200 characters, or any document that got split more than 10 times. These parameters can be changed using the corresponding arguments.

We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with fewer than 200 characters, or any document that got split more than 10 times.
For the lambada task only, we need to provide the path of the test data via `--lambada-path <path of the lambada test data>`.

Several other features (e.g. saving and loading the dictionary) have been added; see `python filter_ngrams.py --help` for details.
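A minimal sketch of the splitting heuristic described above, under the stated defaults (13-grams, 200 characters trimmed on each side, at most 10 splits); the function and variable names here are illustrative and do not come from `filter_ngrams.py`:
```
# Illustrative sketch only: split a document around one matched n-gram.
def split_around_ngram(text, match_start, match_end, trim_chars=200, min_chars=200):
    left = text[:max(match_start - trim_chars, 0)]
    right = text[min(match_end + trim_chars, len(text)):]
    # keep only the pieces that are still long enough
    return [piece for piece in (left, right) if len(piece) >= min_chars]

# Per the description above, a document that gets split more than 10 times
# would be dropped entirely.
```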
+67 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time

"""
This code adds an id to each json object in a json file. The user can add
a prefix to the ids.
"""

if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None, help='Input'\
        ' json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None, help=\
        'Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None, help=\
        'Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                       help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    f_input = open(args.input_file, 'r', encoding='utf-8')
    f_output = open(args.output_file, 'wb')

    unique_ids = 1
    start_time = time.time()
    for row in f_input:
        each_row = json.loads(row)
        adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
        each_row['adlr_id'] = adlr_id_string
        myjson = json.dumps(each_row, ensure_ascii=False)

        f_output.write(myjson.encode('utf-8'))
        f_output.write('\n'.encode('utf-8'))

        if unique_ids % args.log_interval == 0:
            print('    processed {:9d} documents in {:.2f} seconds ...'.format( \
                    unique_ids, time.time() - start_time), flush=True)

        unique_ids += 1

    # Close the file.
    f_input.close()
    f_output.close()
    
    print('done :-)', flush=True)
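A possible invocation of the script above; the script name `add_id.py` and the file names are assumptions for illustration, while the arguments are exactly those defined by the argparse block:
```
python add_id.py --input-file cc_cleaned.json --output-file cc_cleaned_id.json --id-prefix cc --log-interval 10000
```
Each output line then carries an `adlr_id` field such as `cc-0000000001`.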
+191 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Filter and clean documents:
Capable of removing docs with fewer than 512 characters, docs with fewer
than 256 characters that contain javascript, fixing text with ftfy, and
dataset-specific cleaning (e.g. for the stories and realnews datasets).
The program arguments have the details.
"""

import argparse
from functools import partial
import glob
import ftfy
import json
from langdetect import detect
import multiprocessing
import os
from pathlib import Path
import re
import time

def process_doc(json_line, args):

    # Read the line.
    document = json.loads(json_line)
    text = document['text']

    output = {'remove_512': False, 'remove_256_javascript': False, \
        'remove_512_non_english': False, 'ftfy_fix_text': False, \
        'general_cleaning': False}

    try:
        # Remove all docs with fewer than 512 characters
        if "remove_512" in args.tasks:
            if len(text) < 512:
                output['remove_512'] = True
                return output, text, document, True

        # Remove docs with fewer than 256 characters that contain javascript
        if "remove_256_javascript" in args.tasks:
            if len(text) < 256 and 'javascript' in text.lower():
                output['remove_256_javascript'] = True
                return output, text, document, True

        # Remove docs < 512 and nonenglish
        if "remove_512_non_english" in args.tasks:
            if len(text) < 512 and detect(text) != 'en':
                output['remove_512_non_english'] = True
                return output, text, document, True

        # Fix the text using ftfy, don't remove the text, hence return False
        if "ftfy_fix_text" in args.tasks:
            fixed_text = ftfy.fix_text(text)
            output['ftfy_fix_text'] = True
            return output, fixed_text, document, False

        # Cleaning extra spaces and newlines
        if "general_cleaning" in args.tasks:
            cleaned_text = re.sub(r"  +|\b\n+ |\b\n+", " ", text)
            #cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
            #cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews

            # stories datasets
            #cleaned_text = re.sub(r" \'", "'", text)
            #cleaned_text = re.sub(r" \!", "!", cleaned_text)
            #cleaned_text = re.sub(r" \.", ".", cleaned_text)
            #cleaned_text = re.sub(r" \?", "?", cleaned_text)
            #cleaned_text = re.sub(r" - ", "-", cleaned_text)
            ##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
            #cleaned_text = re.sub(r" @ ", "@", cleaned_text)

            output['general_cleaning'] = True
            return output, cleaned_text, document, False

    except Exception as e:
        print('Error: *************************\n{}\ntext: {}'.format(e, \
            text), flush=True)
        return output, text, document, True

    # don't remove
    return output, text, document, False


def process_set(args, input_file, output_f_cleaned, output_f_filtered):

    print(' > working on {} ...'.format(input_file), flush=True)
    
    num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
        = num_ftfy_fix_text = num_general_cleaning = 0

    # Output file and counters.
    output_cleaned = open(output_f_cleaned, 'wb')
    output_filtered = open(output_f_filtered, 'wb')

    start_time = time.time()

    # Setup multi-processing.
    num_workers = 40
    fin = open(input_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    process_doc_partial = partial(process_doc, args=args)
    processed_docs = pool.imap(process_doc_partial, fin, 500)

    # Process documents.
    for output, text, document, to_filter in processed_docs:
        num_docs += 1

        num_remove_512 += 1 if output['remove_512'] else 0
        num_remove_java += 1 if output['remove_256_javascript'] else 0
        num_remove_512_non_english += 1 if output['remove_512_non_english'] \
            else 0
        num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
        num_general_cleaning += 1 if output['general_cleaning'] else 0

        document['text'] = text
        myjson = json.dumps(document, ensure_ascii=False)

        if to_filter:
            output_filtered.write(myjson.encode('utf-8'))
            output_filtered.write('\n'.encode('utf-8'))
        else:
            output_cleaned.write(myjson.encode('utf-8'))
            output_cleaned.write('\n'.encode('utf-8'))

        if num_docs % args.log_interval == 0:
            print('    processed {:9d} documents in {:.2f} seconds ...'.format(
                num_docs, time.time() - start_time), flush=True)

    # Close the file.
    output_cleaned.close()
    output_filtered.close()
    fin.close()

    # Print stats.
    print('  >> total docs: {} remove_512 {} remove_256_javascript {} '\
        'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
        format(num_docs, num_remove_512, num_remove_java,\
        num_remove_512_non_english, num_ftfy_fix_text, \
        num_general_cleaning), flush=True)

if __name__ == '__main__':


    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', nargs = '*', required=True, default=\
                        None, help = 'Input json files that needs to be'\
                        ' cleaned')
    parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
                        help = 'Tasks to perform on the input files, ' \
                        'such as remove_512, remove_256_javascript, ' \
                        'remove_512_non_english, ftfy_fix_text, and ' \
                        'general_cleaning. 256 or 512 means the number' \
                        ' of characters.')

    parser.add_argument('--output-path', type=str, default=None,
                       help='Directory where the output should go')
    parser.add_argument('--log-interval', type=int, default=100,
                       help='Log interval')

    args = parser.parse_args()

    print('cleanup dataset ...')

    for input_file in args.input_files:
        input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
            .name)

        output_f_cleaned = os.path.join(args.output_path, input_filename + \
            "_cleaned" + input_filename_ext)
        output_f_filtered = os.path.join(args.output_path, input_filename + \
            "_filtered" + input_filename_ext)

        process_set(args, input_file, output_f_cleaned, output_f_filtered)

    print('done :-)', flush=True)
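A possible invocation of `cleanup_fix_dataset.py`; the flags are those defined in the argparse block above, and the file names are illustrative:
```
python cleanup_fix_dataset.py --input-files cc.json news.json --tasks remove_512 remove_256_javascript ftfy_fix_text --output-path ./cleaned --log-interval 10000
```
For each input file this writes a `*_cleaned.json` file with the kept (possibly fixed) documents and a `*_filtered.json` file with the removed ones.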
+403 −84

File changed; preview size limit exceeded, changes collapsed.

+195 −54
@@ -14,13 +14,16 @@
# limitations under the License.

import argparse
from functools import partial
import itertools
import json
from lsh import cache, minhash
import multiprocessing
import numpy as np
import time
import pickle
import sys
import os

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
@@ -31,11 +34,158 @@ def shingles(text, char_ngram=5):

# This function is adapted from:
#  https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b):
def jaccard(set_a, set_b, args):
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    intersection = set_a & set_b
    union = set_a | set_b

    if args.jaccard == 'min':
        return len(intersection) / min(len(set_a), len(set_b))
    elif args.jaccard == 'max':
        return len(intersection) / max(len(set_a), len(set_b))
    else:
        return len(intersection) / len(union)
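
# Illustrative example of the args.jaccard options handled above (hypothetical sets):
#   set_a = {1, 2}, set_b = {1, 2, 3, 4}  ->  intersection size 2, union size 4
#   'union' -> 2/4 = 0.5,   'min' -> 2/2 = 1.0,   'max' -> 2/4 = 0.5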

def compute_fingerprint(line, key):
    try:
        myjson = json.loads(line)
        url = myjson[key]
        text = myjson['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as e:
        print('Error:', e)
        return None, None, None, False

    return url, text, fingerprint, True

def url_pairs_to_remove(args, bucket_urls, url_doc):
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
            iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])

        for i in range(0, len(items)):
            counter_local += 1
            other_url = items[i]
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local

def write_remove_urls_list(remove_urls_list, f_out):
    if len(remove_urls_list) > 0:
        for each_url_remove in remove_urls_list:
            myjson = json.dumps(each_url_remove, ensure_ascii=False)
            f_out.write(myjson.encode('utf-8'))
            f_out.write('\n'.encode('utf-8'))

def compute_jaccard(each_bin, num_bins, start_time_local):

    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".\
                format(bucket_local, float(bucket_local)/float(len(each_bin)),\
                time.time() - start_time_local), flush=True)

        if len(each_bin[bucket_id]) <= 1:
            continue

        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local

def find_pair_urls_parallel(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
        start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
        flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
            'seconds and deduped {} documents ...'.format(counter, time.time()\
            - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Time taken for jaccard similarities {:.2f} seconds'.format(\
        time.time() - start_time), flush=True)

def find_pair_urls_sequential(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    for b in lshcache.bins:
        for bucket_id in b:
            if len(b[bucket_id]) <= 1:
                continue

            bucket_urls = b[bucket_id].copy()
            remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                url_pairs_to_remove(args, bucket_urls, url_doc)

            deduped += deduped_local_sub
            counter += counter_local_sub
            write_remove_urls_list(remove_urls_list_sub, f_out)
            if counter % 10000 == 0:
                print(' [write]> processed {} documents in {:.2f} '
                    'seconds and deduped {} documents ...'.
                    format(counter, time.time() - start_time,
                    deduped), flush=True)
    f_out.close()
    print(' [write]> processed {} documents in {:.2f} '
        'seconds and deduped {} documents ...'.
        format(counter, time.time() - start_time,
        deduped), flush=True)

if __name__ == '__main__':

@@ -55,17 +205,30 @@ if __name__ == '__main__':
    parser.add_argument('--output', type=str, default=None,
                       help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'\
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                       help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                       help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                       help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                       help='Use this to process large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=100)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(bands=10, hasher=hasher)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

@@ -91,32 +254,37 @@ if __name__ == '__main__':
    counter = 0
    start_time = time.time()

    print("Computing fingerprints", flush=True)

    # compute fingerprints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)

            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                    fin, 512)
            # traverse all the texts and add fingerprints
            with open(input_file, 'r') as f_input:
                for line in f_input:
                    try:
                        myjson = json.loads(line)
                        url = myjson[key]
                        text = myjson['text']
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                        lshcache.add_fingerprint(hasher.fingerprint(text), url)
                    except Exception as e:
                        print('Error:', e)
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                        'seconds ...'.format(counter, time.time() - \
                        start_time), flush=True)

            fin.close()
            pool.close()
            pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
@@ -125,40 +293,13 @@ if __name__ == '__main__':
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) > 1:
                    items = list(b[bucket_id])
                    main_url = items[0]
                    main_dhingles = shingles(url_doc[main_url])
                    remove_urls = []
                    for i in range(1, len(items)):
                        counter += 1
                        other_url= items[i]
                        other_shingles = shingles(url_doc[other_url])
                        try:
                            jaccard_sim = jaccard(main_dhingles, other_shingles)
                        except Exception as e:
                            print('Error:', e)
                        if jaccard_sim > 0.5:
                            remove_urls.append({other_url: jaccard_sim})
                            deduped += 1
                        if counter % 10000 == 0:
                            print(' [write]> processed {} documents in {:.2f} '
                                  'seoncds and deduped {} documents ...'.
                                  format(counter, time.time() - start_time,
                                         deduped), flush=True)
                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                            ensure_ascii=False)
                        f_out.write(myjson.encode('utf-8'))
                        f_out.write('\n'.encode('utf-8'))
        f_out.close()
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
 