Commit ab754c8c authored by Raul Puri's avatar Raul Puri
Browse files

functionalized code

parent 5dd2a9ad
Loading
Loading
Loading
Loading
+101 −65
Original line number Diff line number Diff line
import torch
import os
import numpy as np
import argparse
import collections

parser = argparse.ArgumentParser()
parser.add_argument('--paths', required=True, nargs='+')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--outdir')
parser.add_argument('--prediction-name', default='test_predictions.pt')
parser.add_argument('--calc-threshold', action='store_true')
parser.add_argument('--one-threshold', action='store_true')
parser.add_argument('--threshold', nargs='+', default=None, type=float)
parser.add_argument('--labels',nargs='+', default=None)
args = parser.parse_args()
import numpy as np
import torch

def process_files(args):
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
@@ -38,10 +29,11 @@ for path in args.paths:
        except Exception as e:
            print(e)
            continue
all_correct = 0
count = 0
def get_threshold(all_predictions, all_labels):
    if args.one_threshold:
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    if one_threshold:
        all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
    out_thresh = []
@@ -50,6 +42,8 @@ def get_threshold(all_predictions, all_labels):
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds,labels))
    return out_thresh


def calc_threshold(p, l):
    trials = [(i)*(1./100.) for i in range(100)]
    best_acc = float('-inf')
@@ -61,6 +55,7 @@ def calc_threshold(p, l):
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
    prob = preds[:,-1]
@@ -69,6 +64,7 @@ def apply_threshold(preds, t):
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds


def threshold_predictions(all_predictions, threshold):
    if len(threshold)!=len(all_predictions):
        threshold = [threshold[-1]]*(len(all_predictions)-len(threshold))
@@ -78,16 +74,24 @@ def threshold_predictions(all_predictions, threshold):
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    """Average ensembled scores and optionally threshold them into hard labels.

    Args:
        all_predictions: dict mapping dataset name -> summed prediction scores
            accumulated over all ensemble members.
        all_labels: dict mapping dataset name -> ground-truth labels.
        args: parsed CLI namespace; uses ``paths`` (ensemble size),
            ``calc_threshold``, ``one_threshold`` and ``threshold``.

    Returns:
        Tuple ``(all_predictions, all_labels)`` with predictions averaged and,
        if a threshold was computed or supplied, converted to hard decisions.
    """
    # Scores were summed across checkpoints; divide to get the ensemble mean.
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        # Derive threshold(s) from the labels (one shared threshold when
        # --one-threshold is set, otherwise one per dataset).
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
@@ -107,3 +111,35 @@ for dataset in all_predictions:
            f.write('\n'.join(str(uid)+'\t'+str(args.labels[p]) for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print(all_correct/count)


def ensemble_predictions(args):
    """Load per-checkpoint predictions, combine them, and write the results.

    Pipeline: read every checkpoint's prediction file, average/threshold the
    combined scores, then emit the final predictions (and accuracy if --eval).
    """
    preds, labels, uids = process_files(args)
    preds, labels = postprocess_predictions(preds, labels, args)
    write_predictions(preds, labels, uids, args)


def main():
    """Parse command-line arguments and run prediction ensembling."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()
 No newline at end of file