Commit 6ae4a03c authored by Mohammad's avatar Mohammad
Browse files

Merge branch 'staging' into staging_gpt2_dataloader

parents e06159f2 7890681a
Loading
Loading
Loading
Loading
+18 −13
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
@@ -66,8 +67,9 @@ def _initialize_distributed():
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
        if device_count > 0:
            device = torch.cuda.current_device()
        local_rank = args.rank % torch.cuda.device_count()
            local_rank = args.rank % device_count
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'

@@ -76,7 +78,8 @@ def _initialize_distributed():
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        device = args.rank % torch.cuda.device_count()
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
@@ -94,6 +97,7 @@ def _initialize_distributed():
            init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        mpu.initialize_model_parallel(args.model_parallel_size)


@@ -112,6 +116,7 @@ def _set_random_seed(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
+145 −0
Original line number Diff line number Diff line
import os
import argparse
import collections

import numpy as np
import torch

def process_files(args):
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            for dataset in data:
                name, d = dataset
                predictions, labels, uid = d
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        args.labels = [i for i in range(all_predictions[name].shape[1])]
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    all_predictions[name] += np.array(predictions)
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            print(e)
            continue
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    if one_threshold:
        all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds,labels))
    return out_thresh


def calc_threshold(p, l):
    trials = [(i)*(1./100.) for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
    prob = preds[:,-1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds


def threshold_predictions(all_predictions, threshold):
    if len(threshold)!=len(all_predictions):
        threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    for d in all_predictions:
        all_predictions[d] = all_predictions[d]/len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            accuracy = correct/num
            count += num
            all_correct += correct
            accuracy = (preds == all_labels[dataset]).mean()
            print(accuracy)
        if not os.path.exists(os.path.join(args.outdir, dataset)):
            os.makedirs(os.path.join(args.outdir, dataset))
        outpath = os.path.join(args.outdir, dataset, os.path.splitext(args.prediction_name)[0]+'.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print(all_correct/count)


def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def  main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use on threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels',nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()
 No newline at end of file