Commit dc3daa58 authored by Isaac's avatar Isaac
Browse files

Adding functionality to HOC finetuning to calculate F1 statistics

parent 87273694
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -80,13 +80,13 @@ def _read_hoc(file_path,dataset_name):
        else:
            continue

    return data_x, data_y
    return data_x, data_y, abstract_ids

def process_single_datapath(datapath, MegatronTokenizer, max_seq_length, dataset_name):

    print_rank_0('   > working on {}'.format(datapath))
    start_time = time.time()
    data_x, data_y = _read_hoc(datapath,dataset_name)
    data_x, data_y, abstract_ids = _read_hoc(datapath,dataset_name)

    samples = []
    num_samples = 0
@@ -99,7 +99,8 @@ def process_single_datapath(datapath, MegatronTokenizer, max_seq_length, dataset
        ids, types, paddings = build_tokens_types_paddings_from_text(
            context, no_context, MegatronTokenizer,  max_seq_length)
        label = data_y[i]
        samples.append(build_sample_hoc(ids,types,paddings,label,num_samples))
        abstract_id = abstract_ids[i]
        samples.append(build_sample_hoc(ids,types,paddings,label,abstract_id))
        num_samples += 1
    
    elapsed_time = time.time() - start_time
+1 −2
Original line number Diff line number Diff line
@@ -114,7 +114,6 @@ def calculate_correct_answers(name, model, dataloader,
        loss_fcn = torch.nn.CrossEntropyLoss()
        num_classes = 10
        loss = None
        correct = np.zeros(num_classes)
        loss_dict = {}
        for i in range(num_classes):
            if loss is None:
@@ -141,7 +140,7 @@ def calculate_correct_answers(name, model, dataloader,
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)
        tokens, types, labels, attention_mask, abstract_ids = process_batch(batch_)

        # Forward model.
        args = get_args()
+264 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time
from functools import partial

import torch
import numpy as np

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
#from megatron.schedules_output import get_forward_backward_func
from megatron.schedules import get_forward_backward_func
from tasks.blurb.hoc.finetune_utils import build_data_loader
from tasks.blurb.hoc.finetune_utils import process_batch
from megatron.utils import average_losses_across_data_parallel_group

from sklearn.metrics import f1_score


def accuracy_f1_func_provider(single_dataset_provider):
    """Provide a callback that computes per-class F1 scores on args.valid_data.

    Builds one dataloader per validation datapath, forcing the micro batch
    size to 1 so every micro-batch maps to exactly one abstract id (required
    by the abstract-level aggregation in ``calculate_correct_answers``).

    Returns:
        metrics_func(model, epoch, output_predictions=False): runs
        ``calculate_correct_answers`` on each dataloader and prints the
        per-class F1 scores on the last rank.  When ``output_predictions``
        is set (data parallel size must be 1) the named predictions are
        saved to ``<args.load>/<names>.pt``.
    """
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        # Batch size must be 1 when calculating F1 scores so each batch
        # contains sentences from a single abstract.
        args.f1_micro_batch_size = 1
        args.f1_global_batch_size = args.f1_micro_batch_size * args.data_parallel_size
        dataloader = build_data_loader(
            dataset, args.f1_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        num_classes = 10
        f1 = np.zeros(num_classes)
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'

        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                f1_scores, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            if mpu.is_pipeline_last_stage():
                # NOTE(review): with several datapaths the per-class F1
                # scores are summed, not averaged, across dataloaders --
                # confirm this is intended when len(datapaths) > 1.
                for i in range(num_classes):
                    f1[i] += f1_scores[i]
                total += total_count
        if is_last_rank():
            # BUG FIX: the previous message printed "correct / total" with a
            # `correct` counter that was never incremented (always 0); report
            # the per-class F1 and the total sample count instead.
            for i in range(num_classes):
                print(' >> |epoch: {}| total count: {} | class {} '
                      'F1 score: {:.4f}'.format(epoch, total, i, f1[i]))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Evaluate `model` on `dataloader` and return per-class F1 statistics.

    Sentences are grouped by their abstract id; an abstract is predicted
    (resp. labeled) positive for a class when any of its sentences is
    positive for that class.  F1 is computed per class over abstracts.

    Returns:
        On the last pipeline stage: ``(f1_scores, total_count)``, or
        ``(correct, total_count, (softmaxes, labels, ids))`` when
        `output_predictions` is set.  Other stages return zeros with the
        matching arity.
    """
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute that means
        # each "sample" from the dataset actually has multiple samples
        # that will collapse into the batch dimension (for example in
        # the RACE dataset that has several options), we need to
        # account for that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.f1_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.f1_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        """Collect per-class argmax predictions for one micro-batch.

        Assumes output_tensor is (batch, num_classes, 2) logits and labels is
        (batch, num_classes) binary targets -- one binary decision per HOC
        class.  The summed cross-entropy loss is computed for parity with
        training but is not reported (0 is returned instead).
        """
        loss_fcn = torch.nn.CrossEntropyLoss()
        num_classes = 10
        loss_dict = {}
        loss = None
        for i in range(num_classes):
            if loss is None:
                loss = loss_fcn(output_tensor[:, i, :], labels[:, i])
            else:
                loss += loss_fcn(output_tensor[:, i, :], labels[:, i])

            predicted = torch.argmax(output_tensor[:, i, :], dim=-1).cpu()
            loss_dict['predicted{%d}' % i] = predicted

        loss_dict['total'] = labels.size(dim=0)

        return 0, loss_dict

    # Defined inside to capture output_predictions.
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask, abstract_ids = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    num_classes = 10
    abstract_scores = {}  # abstract id -> per-class positive-prediction counts
    abstract_truth = {}   # abstract id -> per-class positive-label counts
    correct = np.zeros(num_classes)
    total = 0
    with torch.no_grad():
        # For all the batches in the dataset.
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data,
            # applying sample_multiplier if necessary.
            actual_batch_size = len(batch['label'])
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            tokens, types, labels, attention_mask, abstract_ids = process_batch(batch)

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            # Micro batch size is 1 during F1 evaluation, so every batch maps
            # to a single abstract (index 0).
            abstract_id = abstract_ids.cpu()[0]
            batch_labels = labels.cpu()[0]

            if abstract_id not in abstract_scores:
                abstract_scores[abstract_id] = np.zeros(num_classes)
            if abstract_id not in abstract_truth:
                abstract_truth[abstract_id] = batch_labels
            else:
                abstract_truth[abstract_id] += batch_labels

            for loss_dict in loss_dicts:
                total += loss_dict['total']
                # Accumulate per-class sentence-level predictions for this
                # abstract (previously ten hand-unrolled statements).
                for i in range(num_classes):
                    abstract_scores[abstract_id][i] += loss_dict['predicted{%d}' % i]

    # Collapse sentence-level counts to abstract-level binary labels: a class
    # is positive for an abstract if any sentence count is > 0.
    pred_labels = np.zeros((len(abstract_scores), num_classes), dtype=np.int32)
    actual_labels = np.zeros((len(abstract_scores), num_classes), dtype=np.int32)
    for i, abstract_id in enumerate(abstract_scores.keys()):
        pred_labels[i, :] = np.clip(abstract_scores[abstract_id], 0, 1)
        actual_labels[i, :] = np.clip(abstract_truth[abstract_id], 0, 1)
        correct += 1.0 * ((abstract_scores[abstract_id] > 0) == (abstract_truth[abstract_id] > 0))

    f1 = np.zeros(num_classes)
    for j in range(num_classes):
        f1[j] = f1_score(actual_labels[:, j], pred_labels[:, j])

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce across the data parallel group.
    if mpu.is_pipeline_last_stage():
        f1_scores = np.zeros(num_classes)
        for i in range(num_classes):
            unreduced = torch.cuda.LongTensor([total])
            torch.distributed.all_reduce(unreduced,
                                        group=mpu.get_data_parallel_group())
            total_count = unreduced[0].item()

            # NOTE(review): the per-class F1 is summed, not averaged, across
            # data parallel ranks -- confirm this is the intended reduction.
            unreducedFloat = torch.cuda.FloatTensor([f1[i]])
            torch.distributed.all_reduce(unreducedFloat,
                                        group=mpu.get_data_parallel_group())
            f1_scores[i] = unreducedFloat[0].item()

        if output_predictions:
            # BUG FIX: `correct_ans` was never assigned on this path (NameError
            # when output_predictions is true); return the per-class correct
            # counts instead.  NOTE(review): softmaxes/labels/ids are never
            # populated in the loop, so the predictions tuple is a placeholder.
            return correct, total_count, (softmaxes, labels, ids)
        return f1_scores, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
+20 −1
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification_hoc import Classification_hoc
from tasks.blurb.hoc.eval_utils import accuracy_func_provider
from tasks.blurb.hoc.f1_utils import accuracy_f1_func_provider
from tasks.blurb.hoc.finetune_utils import finetune
from tasks.blurb.hoc.data import HOCDataset

@@ -56,6 +57,24 @@ def metrics_func_provider():
        return HOCDataset('dev', [datapath], tokenizer, args.seq_length)
    return accuracy_func_provider(single_dataset_provider)

def test_metrics_func_provider():
    """Provide an accuracy callback that evaluates on the HOC test split."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Build the HOC test dataset for a single datapath.
        return HOCDataset('test', [datapath], tokenizer, args.seq_length)
    return accuracy_func_provider(single_dataset_provider)

def f1_func_provider():
    """Provide an F1 callback that evaluates on the HOC test split."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Build the HOC test dataset for a single datapath.
        return HOCDataset('test', [datapath], tokenizer, args.seq_length)
    return accuracy_f1_func_provider(single_dataset_provider)

def main():
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
             end_of_epoch_callback_provider=metrics_func_provider,
             end_of_training_eval_callback_provider=test_metrics_func_provider, 
             end_of_training_f1_callback_provider=f1_func_provider)
+21 −4
Original line number Diff line number Diff line
@@ -43,10 +43,11 @@ def process_batch(batch):
    types = batch['types'].long().cuda().contiguous()
    labels = batch['label'].long().cuda().contiguous()
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    abstract_ids = batch['uid'].long().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask
    return tokens, types, labels, attention_mask, abstract_ids


def cross_entropy_loss_func(labels, output_tensor):
@@ -79,7 +80,7 @@ def _cross_entropy_forward_step(batch, model):
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    tokens, types, labels, attention_mask, abstract_ids = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
@@ -162,7 +163,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,


def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
           train_dataloader, valid_dataloader, end_of_epoch_callback, end_of_training_eval_callback, end_of_training_f1_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()
@@ -254,11 +255,19 @@ def _train(model, optimizer, lr_scheduler, forward_step,
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
    
    if end_of_training_eval_callback is not None:
        end_of_training_eval_callback(model, args.epochs)

    if end_of_training_f1_callback is not None:
        end_of_training_f1_callback(model, args.epochs)


def finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             end_of_training_eval_callback_provider=None,
             end_of_training_f1_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks."""
    args = get_args()
@@ -282,6 +291,14 @@ def finetune(train_valid_datasets_provider, model_provider,
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()

    end_of_training_f1_callback = None
    if end_of_training_f1_callback_provider is not None:
        end_of_training_f1_callback = end_of_training_f1_callback_provider()

    end_of_training_eval_callback = None
    if end_of_training_eval_callback_provider is not None:
        end_of_training_eval_callback = end_of_training_eval_callback_provider()
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
@@ -315,7 +332,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, lr_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
               train_dataloader, valid_dataloader, end_of_epoch_callback, end_of_training_eval_callback, end_of_training_f1_callback)
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None: