Commit 43d307d5 authored by mpatwary's avatar mpatwary

added parallelism for computing jaccard similarity

parent 882683dc
+138 −57
@@ -23,6 +23,7 @@ import numpy as np
import time
import pickle
import sys
import os

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
@@ -59,6 +60,133 @@ def compute_fingerprint(line, key):

    return url, text, fingerprint, True

def url_pairs_to_remove(args, bucket_urls, url_doc):
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
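    # heuristic: repeatedly pick a random "main" document from the bucket and
    # mark every other document whose jaccard similarity with it exceeds 0.5
    # for removal, for at most args.heuristic_iter iterations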
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
            iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])

        for i in range(0, len(items)):
            counter_local += 1
            other_url = items[i]
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local
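
# For reference, a minimal sketch of what the shingles() and jaccard() helpers
# used above might look like; this is an assumption based on the LSH example
# notebook referenced near the top of the file, the actual definitions live
# elsewhere in this script:
#
#   def shingles(text, char_ngram=5):
#       # character n-gram shingles of a document
#       return set(text[head:head + char_ngram]
#                  for head in range(0, len(text) - char_ngram + 1))
#
#   def jaccard(set_a, set_b, args):
#       # plain Jaccard index of two shingle sets
#       union = set_a | set_b
#       return len(set_a & set_b) / len(union) if union else 0.0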

def write_remove_urls_list(remove_urls_list, f_out):
    if len(remove_urls_list) > 0:
        for each_url_remove in remove_urls_list:
            myjson = json.dumps(each_url_remove, ensure_ascii=False)
            f_out.write(myjson.encode('utf-8'))
            f_out.write('\n'.encode('utf-8'))

def compute_jaccard(each_bin, num_bins, start_time_local):

    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
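        # progress is printed by (roughly) one worker only, selected by its
        # pid, and only once every 100k buckets to keep the log readable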
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".\
                format(bucket_local, float(bucket_local)/float(len(each_bin)),\
                time.time() - start_time_local), flush=True)

        if len(each_bin[bucket_id]) <= 1:
            continue

        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local

def find_pair_urls_parallel(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute the jaccard similarity of the buckets in each bin in parallel
    # (parallelism is limited to the number of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
        start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
        flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
            'seconds and deduped {} documents ...'.format(counter, time.time()\
            - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Time taken for jaccard similarities {:.2f} seconds'.format(\
        time.time() - start_time), flush=True)
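
# Note: lshcache.bins is assumed to be a list with one dict per band (as in the
# LSH cache implementation this script builds on), mapping a bucket hash to the
# set of urls that fell into that bucket, e.g.
#
#   lshcache.bins = [
#       {bucket_hash: {'url-a', 'url-b'}, ...},   # band 0
#       {bucket_hash: {'url-c', 'url-d'}, ...},   # band 1
#       ...
#   ]
#
# pool.imap() above hands one such dict (one bin) to each worker, which is why
# the parallelism is capped at the number of bins.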

def find_pair_urls_sequential(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    for b in lshcache.bins:
        for bucket_id in b:
            if len(b[bucket_id]) <= 1:
                continue

            bucket_urls = b[bucket_id].copy()
            remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                url_pairs_to_remove(args, bucket_urls, url_doc)

            deduped += deduped_local_sub
            counter += counter_local_sub
            write_remove_urls_list(remove_urls_list_sub, f_out)
            if counter % 10000 == 0:
                print(' [write]> processed {} documents in {:.2f} '
                    'seconds and deduped {} documents ...'.
                    format(counter, time.time() - start_time,
                    deduped), flush=True)
    f_out.close()
    print(' [write]> processed {} documents in {:.2f} '
        'seconds and deduped {} documents ...'.
        format(counter, time.time() - start_time,
        deduped), flush=True)

if __name__ == '__main__':

    print('parsing the arguments ...')
@@ -88,7 +216,8 @@ if __name__ == '__main__':
    parser.add_argument('--num-seeds', type=int, default=100,
                       help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')

    parser.add_argument('--jaccard-parallel', action='store_true',
                       help='Use this to process a large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')
@@ -125,23 +254,22 @@ if __name__ == '__main__':
    counter = 0
    start_time = time.time()

    print("Computing fingerprints", flush=True)

    # compute fingerprints of the inputs, if any; inputs are given as
    # (input file, key to use as id) pairs
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)

            # compute fingerprints in parallel
            num_workers = 20
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                    fin, 500)
                                                    fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
@@ -165,60 +293,13 @@ if __name__ == '__main__':
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) <= 1:
                    continue

                bucket_urls = b[bucket_id].copy()
                iteration = 0
                while len(bucket_urls) > 1:
                    if args.heuristic_iter != -1 and \
                        iteration == args.heuristic_iter:
                        break

                    items = list(bucket_urls)
                    remove_urls = []
                    main_url = items[np.random.randint(0, len(items))]
                    main_dhingles = shingles(url_doc[main_url])

                    for i in range(0, len(items)):
                        counter += 1
                        other_url= items[i]
                        if other_url == main_url:
                            continue
                        other_shingles = shingles(url_doc[other_url])
                        try:
                            jaccard_sim = jaccard(main_dhingles, other_shingles,
                                                    args)
                        except Exception as e:
                            print('Error:', e)
                            jaccard_sim = 0.0
                        if jaccard_sim > 0.5:
                            remove_urls.append({other_url: jaccard_sim})
                            deduped += 1
                            bucket_urls.remove(other_url)
                        if counter % 10000 == 0:
                            print(' [write]> processed {} documents in {:.2f} '
                                'seoncds and deduped {} documents ...'.
                                format(counter, time.time() - start_time,
                                deduped), flush=True)

                    bucket_urls.remove(main_url)
                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                        ensure_ascii=False)
                        f_out.write(myjson.encode('utf-8'))
                        f_out.write('\n'.encode('utf-8'))
                    iteration += 1

        f_out.close()
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
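
A possible invocation of the updated script (the file name, input file names and the 'url' key are hypothetical; the flags are the ones defined above):

    python find_duplicates.py \
        --inputs web_docs.json url news_docs.json url \
        --output duplicate_urls.json \
        --num-seeds 100 \
        --heuristic-iter 1 \
        --jaccard-parallel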