"""Ensemble multiple classifier prediction checkpoints into one prediction set.

Averages per-dataset class probabilities saved by several training runs,
optionally grid-searches for (or applies) a decision threshold on the last
class, and writes TSV prediction files; with --eval it also prints accuracy.

NOTE(review): this module was recovered from a truncated diff; function
interiors whose lines were hidden in collapsed hunks are reconstructed to
match the visible fragments -- confirm against the repository original.
"""
import os
import argparse
import collections

import numpy as np
import torch


def process_files(args):
    """Load and sum predictions from every checkpoint directory in args.paths.

    Returns (all_predictions, all_labels, all_uid): OrderedDicts keyed by
    dataset name.  Predictions are summed across paths (they are averaged
    later in postprocess_predictions).  Unreadable paths are skipped with a
    printed warning rather than aborting the whole ensemble.
    """
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    # NOTE(review): loop interior reconstructed from the visible
    # try/except/continue skeleton -- verify against the original file.
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            for dataset in data:
                name, d = dataset
                predictions, labels, uid = d
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        # Fall back to numeric label names sized to the class dim.
                        args.labels = [i for i in range(all_predictions[name].shape[1])]
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    all_predictions[name] += np.array(predictions)
                    # Every ensembled run must have scored the same examples.
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            # Best effort: a missing/corrupt checkpoint should not kill the run.
            print(e)
            continue
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    """Grid-search a decision threshold per dataset.

    Args:
        all_predictions: dict of dataset name -> (N, C) probability array.
        all_labels: dict of dataset name -> (N,) label array.
        one_threshold: if True, search a single shared threshold over the
            concatenation of all datasets.

    Returns:
        A list with one threshold per dataset (length 1 when one_threshold).
    """
    if one_threshold:
        # BUG FIX: the original assigned the combined predictions to a
        # misspelled local ('all_predictons') -- a dead store -- and called
        # all_predictions.labels(), which does not exist on a dict.  Combine
        # both predictions and labels into a single pseudo-dataset instead.
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh


def calc_threshold(p, l):
    """Return the threshold in {0.00, 0.01, ..., 0.99} that maximizes accuracy
    of apply_threshold(p, t) against labels l (first best wins on ties).
    """
    trials = [i * (1. / 100.) for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    # NOTE(review): search loop reconstructed from the visible
    # 'best_thresh = t / return best_thresh' fragment.
    for t in trials:
        acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    """Convert probabilities to one-hot predictions: predict the last class
    iff its probability >= t.

    Rows of preds must sum to 1.  The scatter below writes column index 0/1,
    so this assumes binary (two-column) predictions -- TODO confirm.
    """
    assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds


def threshold_predictions(all_predictions, threshold):
    """Apply one threshold per dataset via apply_threshold.

    A threshold list shorter than the number of datasets is padded by
    repeating its last entry.  BUG FIX: the original REPLACED the list with
    the padding, dropping the user-supplied values and leaving the list too
    short (IndexError); we extend it instead.
    """
    if len(threshold) != len(all_predictions):
        threshold = list(threshold) + \
            [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    """Average summed predictions over the ensemble size and (optionally)
    compute and/or apply decision thresholds.  May set args.threshold."""
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels,
                                       args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    """Write one '<outdir>/<dataset>/<prediction-name>.tsv' file per dataset
    (id<TAB>label), printing per-dataset and overall accuracy under --eval.
    """
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        # NOTE(review): eval accumulation and output-path handling
        # reconstructed from the visible fragments -- verify paths/format.
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            count += num
            all_correct += correct
            print((preds == all_labels[dataset]).mean())
        if not os.path.exists(os.path.join(args.outdir, dataset)):
            os.makedirs(os.path.join(args.outdir, dataset))
        outpath = os.path.join(
            args.outdir, dataset,
            os.path.splitext(args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset],
                                                preds.tolist())))
    if args.eval:
        print(all_correct / count)


def ensemble_predictions(args):
    """End-to-end driver: load, postprocess, and write ensembled predictions."""
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(
        all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    """Parse command-line arguments and run the ensembling pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    # BUG FIX (typo): help text read 'use on threshold'.
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()