Commit a1f4032b authored by Jared Casper

Merge branch 'main_dedup' into 'main'

Storing and loading fingerprints in deduplication

See merge request ADLR/megatron-lm!236
parents 90623142 ebc4a408
+2 −2
@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset.
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplication.
```
python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
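For example, a recurrent deduplication could save fingerprints on the first pass and reuse them when new data arrives later (the `.pkl` and output file names below are illustrative):
```
# first pass: fingerprint cc.json and save its fingerprints for later reuse
python find_duplicates.py --inputs cc.json cc_id --save-fingerprints cc.pkl --output cc_duplicates.json
# later pass: reuse the saved cc fingerprints while adding news.json
python find_duplicates.py --load-fingerprints cc.pkl --inputs news.json news_id --output cc_news_duplicates.json
```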
3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, we should keep only one url and remove the rest.
```
+87 −23
@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import itertools
import json
from lsh import cache, minhash
import numpy as np
import time
import pickle
import sys


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
@@ -38,22 +39,73 @@ def jaccard(set_a, set_b):

if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs = '*', default=None, help = \
                        'Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs = '*', default=None,
                       help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                       help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file name that consists of all ids'
                        ' with matching similarities')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    input = sys.argv[1]
    output = sys.argv[2]
    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=100)

    hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(bands=10, hasher=hasher)

    counter = 0
    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()

    counter = 0
    start_time = time.time()
    with open(input, 'r') as f:
        for line in f:

    print("Computing fingerprints", flush=True)

    # compute fingerprints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)
            # traverse all the texts and add fingerprints
            with open(input_file, 'r') as f_input:
                for line in f_input:
                    try:
                        myjson = json.loads(line)
                url = myjson['url']
                        url = myjson[key]
                        text = myjson['text']
                        counter += 1
                        url_doc[url] = text
@@ -61,13 +113,24 @@ if __name__ == '__main__':
                    except Exception as e:
                        print('Error:', e)
                    if counter % 10000 == 0:
                print(' [read]> processed {} documents in {:.2f} seconds ...'.
                      format(counter, time.time() - start_time), flush=True)
                        print(' [read]> processed {} documents in {:.2f} '
                            'seconds ...'.format(counter, time.time() - \
                            start_time), flush=True)

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0
    with open(output, 'wb') as f:
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) > 1:
@@ -94,7 +157,8 @@ if __name__ == '__main__':
                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                            ensure_ascii=False)
                        f.write(myjson.encode('utf-8'))
                        f.write('\n'.encode('utf-8'))
                        f_out.write(myjson.encode('utf-8'))
                        f_out.write('\n'.encode('utf-8'))
        f_out.close()

    print('done :-)')
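For reference, below is a minimal sketch of the pickle layout the new `--save-fingerprints` / `--load-fingerprints` flags rely on, assuming the same `lsh` package (`cache.Cache`, `minhash.MinHasher`) used in the script above; the file name `cc.pkl` is illustrative. Each file holds the LSH cache followed by the url-to-text map, and a later run unpickles both in the same order and merges the stored fingerprints into its own cache.
```
# Sketch (illustrative, mirroring find_duplicates.py above): save and reload
# the LSH fingerprints so a later deduplication run can reuse them.
import pickle
from lsh import cache, minhash

hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(bands=10, hasher=hasher)
url_doc = {}

# ... fingerprints are added here exactly as in the script above ...

# --save-fingerprints: pickle the cache first, then the url -> text map.
with open('cc.pkl', 'wb') as f_save:
    pickle.dump(lshcache, f_save)
    pickle.dump(url_doc, f_save)

# --load-fingerprints: unpickle in the same order and merge into the current
# cache, so previously fingerprinted documents join the new run.
with open('cc.pkl', 'rb') as f_load:
    loaded_cache = pickle.load(f_load)
    loaded_url_doc = pickle.load(f_load)
for url in loaded_cache.fingerprints.keys():
    url_doc[url] = loaded_url_doc[url]
    lshcache.add_fingerprint(loaded_cache.fingerprints[url], url)
```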