Commit 882683dc authored by Mostofa Patwary

Added more features to train data deduplication

parent 47e7eb80
+48 −0
import argparse
import json
import os
import time


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None, help='Input'\
        ' json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None, help=\
        'Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None, help=\
        'Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                       help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    f_input = open(args.input_file, 'r', encoding='utf-8')
    f_output = open(args.output_file, 'wb')

    unique_ids = 1
    start_time = time.time()
    for row in f_input:
        each_row = json.loads(row)
        adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
        each_row['adlr_id'] = adlr_id_string
        myjson = json.dumps(each_row, ensure_ascii=False)

        f_output.write(myjson.encode('utf-8'))
        f_output.write('\n'.encode('utf-8'))

        if unique_ids % args.log_interval == 0:
            print('    processed {:9d} documents in {:.2f} seconds ...'.format( \
                    unique_ids, time.time() - start_time), flush=True)

        unique_ids += 1

    # Close the files.
    f_input.close()
    f_output.close()
    
    print('done :-)', flush=True)
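
For a concrete picture of what this script produces, here is a minimal sketch of the same logic applied to one in-memory line; the input text and the --id-prefix value are invented for illustration.

import json

# Hypothetical input line and id prefix, mirroring the loop above.
row = '{"text": "an example training document"}'
id_prefix = 'owt'   # would come from --id-prefix
unique_ids = 1      # running counter, starts at 1

each_row = json.loads(row)
each_row['adlr_id'] = id_prefix + '-{:010d}'.format(unique_ids)
print(json.dumps(each_row, ensure_ascii=False))
# -> {"text": "an example training document", "adlr_id": "owt-0000000001"}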
+89 −29
@@ -14,9 +14,11 @@
# limitations under the License.

import argparse
+from functools import partial
import itertools
import json
from lsh import cache, minhash
+import multiprocessing
import numpy as np
import time
import pickle
@@ -31,11 +33,31 @@ def shingles(text, char_ngram=5):

# This function is adapted from:
#  https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
-def jaccard(set_a, set_b):
+def jaccard(set_a, set_b, args):
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    intersection = set_a & set_b
    union = set_a | set_b

-    return len(intersection) / len(union)
+    if args.jaccard == 'min':
+        return len(intersection) / min(len(set_a), len(set_b))
+    elif args.jaccard == 'max':
+        return len(intersection) / max(len(set_a), len(set_b))
+    else:
+        return len(intersection) / len(union)
+
+def compute_fingerprint(line, key):
+    try:
+        myjson = json.loads(line)
+        url = myjson[key]
+        text = myjson['text']
+        fingerprint = hasher.fingerprint(text)
+    except Exception as e:
+        print('Error:', e)
+        return None, None, None, False
+
+    return url, text, fingerprint, True
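
To make the three --jaccard options introduced in this commit concrete, here is a small standalone check (not part of the commit) on toy shingle sets:

# Toy shingle sets: the intersection has 2 elements,
# |set_a| = 3, |set_b| = 4, and the union has 5 elements.
set_a = {'abcde', 'bcdef', 'cdefg'}
set_b = {'abcde', 'bcdef', 'vwxyz', 'wxyzv'}

intersection = set_a & set_b
union = set_a | set_b
print(len(intersection) / len(union))                   # 'union': 2/5 = 0.40
print(len(intersection) / min(len(set_a), len(set_b)))  # 'min':   2/3 ~ 0.67
print(len(intersection) / max(len(set_a), len(set_b)))  # 'max':   2/4 = 0.50

Since the union is always at least as large as either set, 'min' yields the highest scores (and so flags the most pairs at a fixed threshold) while 'union' yields the lowest.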

if __name__ == '__main__':

@@ -55,17 +77,29 @@ if __name__ == '__main__':
    parser.add_argument('--output', type=str, default=None,
                       help='Output file name that consists of all ids'
                        ' with matching similarities')
+    parser.add_argument('--jaccard', type=str, default='union',
+                        choices=['union', 'min', 'max'], help='Jaccard'\
+                        ' similarity computation')
+    parser.add_argument('--heuristic-iter', type=int, default=1,
+                       help='Number of iterations to run the heuristics'
+                        ': use -1 for exact')
+    parser.add_argument('--num-bands', type=int, default=10,
+                       help='Number of bands to use in cache')
+    parser.add_argument('--num-seeds', type=int, default=100,
+                       help='Number of seeds to use for minhash. Note that'
+                        ' this value should be divisible by num-bands')

    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
-    seeds = np.random.randint(0, 1e6, size=100)
+    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
-    lshcache = cache.Cache(bands=10, hasher=hasher)
+    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}
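
The note that --num-seeds should be divisible by --num-bands exists because the cache splits each fingerprint of num-seeds minhash values into num-bands equal bands. A sketch of the setup with an explicit guard, using the same lsh package imported above (the two example texts are invented):

import numpy as np
from lsh import cache, minhash

num_seeds, num_bands = 100, 10
# Each band holds num_seeds // num_bands hashes, so this must divide evenly.
assert num_seeds % num_bands == 0

np.random.seed(1234)
seeds = np.random.randint(0, 1e6, size=num_seeds)
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(num_bands=num_bands, hasher=hasher)

# Near-duplicate texts should collide in at least one band's bucket.
lshcache.add_fingerprint(hasher.fingerprint('the quick brown fox jumps over the dog'), 'url-a')
lshcache.add_fingerprint(hasher.fingerprint('the quick brown fox jumps over the dogs'), 'url-b')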

@@ -100,23 +134,29 @@ if __name__ == '__main__':
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)

+            # compute fingerprints in parallel
+            num_workers = 20
+            pool = multiprocessing.Pool(num_workers)
+            fin = open(input_file, 'r', encoding='utf-8')
+            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
+            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
+                                                    fin, 500)
-            # traverse all the texts and add fingerprints
-            with open(input_file, 'r') as f_input:
-                for line in f_input:
-                    try:
-                        myjson = json.loads(line)
-                        url = myjson[key]
-                        text = myjson['text']
+            for url, text, fingerprint, flag in compute_fingerprint_iter:
+                counter += 1
+                if flag:
+                    url_doc[url] = text
-                        lshcache.add_fingerprint(hasher.fingerprint(text), url)
-                    except Exception as e:
-                        print('Error:', e)
+                    lshcache.add_fingerprint(fingerprint, url)
+                if counter % 10000 == 0:
+                    print(' [read]> processed {} documents in {:.2f} '
+                        'seconds ...'.format(counter, time.time() - \
+                        start_time), flush=True)

+            fin.close()
+            pool.close()
+            pool.join()
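
The parallel read added here is the standard Pool/imap idiom: a worker bound to a fixed key with functools.partial, streamed over the open file with a chunksize of 500. A self-contained toy version of the same idiom (the work function is a stand-in, not compute_fingerprint):

import multiprocessing
from functools import partial

def work(line, key):
    # Stand-in worker: pair each stripped line with the fixed key.
    return key, line.strip()

if __name__ == '__main__':
    lines = ['doc one\n', 'doc two\n', 'doc three\n']
    pool = multiprocessing.Pool(2)
    # imap streams results back in input order; the last argument is the
    # chunksize (500 in the script above).
    for key, text in pool.imap(partial(work, key='text'), lines, 2):
        print(key, text)
    pool.close()
    pool.join()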

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
@@ -133,32 +173,52 @@ if __name__ == '__main__':
        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
-                if len(b[bucket_id]) > 1:
-                    items = list(b[bucket_id])
-                    main_url = items[0]
-                    main_dhingles = shingles(url_doc[main_url])
+                if len(b[bucket_id]) <= 1:
+                    continue
+
+                bucket_urls = b[bucket_id].copy()
+                iteration = 0
+                while len(bucket_urls) > 1:
+                    if args.heuristic_iter != -1 and \
+                        iteration == args.heuristic_iter:
+                        break
+
+                    items = list(bucket_urls)
+                    remove_urls = []
-                    for i in range(1, len(items)):
+                    main_url = items[np.random.randint(0, len(items))]
+                    main_dhingles = shingles(url_doc[main_url])
+
+                    for i in range(0, len(items)):
                        counter += 1
                        other_url = items[i]
+                        if other_url == main_url:
+                            continue
                        other_shingles = shingles(url_doc[other_url])
                        try:
-                            jaccard_sim = jaccard(main_dhingles, other_shingles)
+                            jaccard_sim = jaccard(main_dhingles, other_shingles,
+                                                    args)
                        except Exception as e:
                            print('Error:', e)
                            jaccard_sim = 0.0
                        if jaccard_sim > 0.5:
                            remove_urls.append({other_url: jaccard_sim})
                            deduped += 1
+                            bucket_urls.remove(other_url)
+                        if counter % 10000 == 0:
+                            print(' [write]> processed {} documents in {:.2f} '
+                                'seconds and deduped {} documents ...'.
+                                format(counter, time.time() - start_time,
+                                deduped), flush=True)

+                    bucket_urls.remove(main_url)
                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                        ensure_ascii=False)
                        f_out.write(myjson.encode('utf-8'))
                        f_out.write('\n'.encode('utf-8'))
+                    iteration += 1

        f_out.close()

    print('done :-)')
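
Each line written to --output maps the randomly chosen 'main' id of a bucket to the ids judged similar to it; this is the format the grouping script below consumes. A sketch of reading one such line (ids and scores are invented):

import json

# One output line; the structure matches the json.dumps call above:
# {main_url: [{other_url: jaccard_sim}, ...]}
line = '{"id-0000000007": [{"id-0000000042": 0.83}, {"id-0000000913": 0.61}]}'

myjson = json.loads(line)
for main_url, matches in myjson.items():
    for value in matches:
        for other_url, js in value.items():
            print(main_url, other_url, js)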
 
+5 −5
@@ -18,10 +18,6 @@ import time
import sys


-def is_similar(jaccard_similarity):
-    return (jaccard_similarity >= 0.7)
-
-
if __name__ == '__main__':


@@ -29,6 +25,10 @@ if __name__ == '__main__':

    input = sys.argv[1]
    output = sys.argv[2]
+    if len(sys.argv) > 3:
+        jaccard_similarity_threshold = float(sys.argv[3])
+    else:
+        jaccard_similarity_threshold = 0.7

    url_to_index = {}
    index_to_urls = []
@@ -43,7 +43,7 @@ if __name__ == '__main__':
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
-                        if is_similar(js):
+                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            current_index = -1
            other_indices = set()
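
For context, url_to_index and index_to_urls implement a simple grouping of duplicate urls. A condensed sketch of how such grouping can work; this simplification reuses the first existing group it finds and skips the multi-group merge the real script tracks via other_indices:

url_to_index = {}
index_to_urls = []

def add_group(urls):
    # Reuse the group of the first already-seen url, else open a new group.
    current_index = -1
    for url in urls:
        if url in url_to_index:
            current_index = url_to_index[url]
            break
    if current_index == -1:
        current_index = len(index_to_urls)
        index_to_urls.append(set())
    for url in urls:
        url_to_index[url] = current_index
        index_to_urls[current_index].add(url)

add_group(['a', 'b'])
add_group(['b', 'c'])
print(index_to_urls[url_to_index['a']])   # {'a', 'b', 'c'} in some order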