Commit de6640be authored by Neel Kant
Browse files

Made topk accuracy reporting optional

parent 8d7f508a
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -414,5 +414,9 @@ def _add_realm_args(parser):
    group.add_argument('--ict-one-sent', action='store_true',
                       help='Whether to use one sentence documents in ICT')

    # training
    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")

    return parser
+7 −10
Original line number Diff line number Diff line
@@ -116,19 +116,16 @@ def forward_step(data_iterator, model):
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_acc(k):
    def topk_accuracy(k):
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
    top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
    reduced_losses = reduce_losses([retrieval_loss, *top_accs])
    stats_dict = {
        'retrieval loss': reduced_losses[0],
        'top1_acc': reduced_losses[1],
        'top8_acc': reduced_losses[2],
        'top20_acc': reduced_losses[3],
        'top100_acc': reduced_losses[4],
    }
    reduced_losses = reduce_losses([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)

    return retrieval_loss, stats_dict