Commit 87273694 authored by Isaac's avatar Isaac
Browse files

Adding Blurb HOC Finetune task

parent f187f762
Loading
Loading
Loading
Loading

finetune_hoc.lsf

0 → 100644
+69 −0
Original line number Diff line number Diff line
#!/bin/bash
#
# LSF batch script: fine-tune a Megatron BERT checkpoint on the BLURB
# HOC (Hallmarks of Cancer) classification task.
# Submit with: bsub finetune_hoc.lsf

#BSUB -nnodes 2
#BSUB -W 0:30
#BSUB -P med106
#BSUB -alloc_flags "smt4 nvme"
#BSUB -J hoc
#BSUB -o hoc.%J
#BSUB -q batch

set +x

# Software environment: open-ce module plus a user conda env with PyTorch.
#module load open-ce/1.4.0-py38-0
module load open-ce
conda deactivate
conda activate /gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch
module list

export OMP_NUM_THREADS=1
#export PYTHONPATH=$PYTHONPATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/megatron/fused_kernels
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/lib
#export PATH=$PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/include

# Count compute nodes from the LSF host file: de-duplicate the per-slot
# entries and drop the login/batch service nodes.
mapfile -t nodes < <(sort -u "${LSB_DJOB_HOSTFILE}" | grep -Ev 'login|batch')
nnodes=${#nodes[@]}
echo "${nnodes}"

# Train/validation data directories (each holds <split>.tsv files).
#export TRAIN_DATA=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/picodata/train.tsv
#TRAIN_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata/train.tsv"
TRAIN_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata_small"
#export TRAIN_DATA=picodata/train.tsv
#export VALID_DATA=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/picodata/dev.tsv
#VALID_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata/dev.tsv"
VALID_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata_small"
#export VALID_DATA=picodata/dev.tsv

export VOCAB_FILE=/gpfs/alpine/world-shared/med106/g8o/pubmed_bert-vocab.txt
export CHECKPOINT_PATH=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/chkpt_222
export PRETRAINED_CHECKPOINT=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/chkptt

# One resource set per node: 6 GPUs, 6 tasks, 42 cores each.
jsrun --smpiargs="-disable_gpu_hooks" -n "${nnodes}" -r 1 -g 6 -a 6 -c 42 python tasks/main.py \
       --task HOC \
       --tensor-model-parallel-size 2 \
       --pipeline-model-parallel-size 2 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --fp16 \
       --vocab-file "${VOCAB_FILE}" \
       --train-data "${TRAIN_DATA}" \
       --valid-data "${VALID_DATA}" \
       --pretrained-checkpoint "${PRETRAINED_CHECKPOINT}" \
       --activations-checkpoint-method uniform \
       --save-interval 10000 \
       --save "${CHECKPOINT_PATH}" \
       --log-interval 100 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --weight-decay 1e-2 \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 1 \
       --micro-batch-size 4 \
       --lr 0.0001 \
       --lr-warmup-fraction 0.06 \
       --distributed-backend nccl
       #--DDP-impl torch \
+139 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classification model."""

import torch

from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


class Classification_hoc(MegatronModule):
    """Multi-label classification model for the HOC task.

    Wraps a Megatron BERT language model with a two-layer head that
    emits, for each of ``num_classes`` labels, a pair of binary logits
    (output shape ``[batch, num_classes, 2]``).
    """

    def __init__(self,
                 num_classes,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super(Classification_hoc, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        weight_init = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=weight_init,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Classification head lives only on the last pipeline stage.
        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            # First projection: hidden -> hidden (followed by ReLU in forward).
            self.classification_head0 = get_linear_layer(
                args.hidden_size, args.hidden_size, weight_init)
            # Second projection: hidden -> two logits per class.
            self.classification_head1 = get_linear_layer(
                args.hidden_size, 2 * self.num_classes, weight_init)
            # Checkpoint keys for the two head layers.
            self._classification_head_key0 = 'classification_head0'
            self._classification_head_key1 = 'classification_head1'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
        """Run BERT and, on the last stage, the classification head."""
        input_ids = model_input
        extended_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_mask,
                                        tokentype_ids=tokentype_ids)

        if not self.post_process:
            # Intermediate pipeline stages just forward the LM activations.
            return lm_output

        _, pooled = lm_output
        hidden = self.classification_head0(self.classification_dropout(pooled))
        logits = self.classification_head1(torch.nn.ReLU()(hidden))
        # [batch, 2*num_classes] -> [batch, num_classes, 2] so a softmax
        # over the last axis gives a per-class binary decision.
        return logits.reshape(logits.shape[0], self.num_classes, 2)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state = {}
        state[self._language_model_key] = \
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state[self._classification_head_key0] = \
                self.classification_head0.state_dict(
                    destination, prefix, keep_vars)
            state[self._classification_head_key1] = \
                self.classification_head1.state_dict(
                    destination, prefix, keep_vars)
        return state

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: tolerate checkpoints without the head weights
        (e.g. a pretrained LM-only checkpoint) by leaving them random."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            heads = ((self._classification_head_key0, self.classification_head0),
                     (self._classification_head_key1, self.classification_head1))
            for key, head in heads:
                if key in state_dict:
                    head.load_state_dict(state_dict[key], strict=strict)
                else:
                    print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                    'initializing to random'.format(key))
+109 −0
Original line number Diff line number Diff line

import glob
import os
import time

import torch
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import get_args
from tasks.data_utils import build_sample
from tasks.data_utils import build_sample_hoc
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import build_tokens_types_paddings_from_text
from tasks.data_utils import clean_text

from transformers import AutoTokenizer
from transformers import BertTokenizerFast
from transformers import PreTrainedTokenizerFast
from pathlib import Path
import re
import numpy as np


class HOCDataset(Dataset):
    """Torch dataset for the BLURB HOC task.

    Loads the ``<dataset_name>.tsv`` split from every directory in
    ``datapaths`` and tokenizes each sentence into Megatron samples.
    """

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,ignore_index=-100, tasks=['I-PAR', 'I-INT', 'I-OUT']):
        # get_args() verifies megatron has been initialized.
        args = get_args()

        self.dataset_name = dataset_name
        print_rank_0(' > building HOC dataset for {}:'.format(
            self.dataset_name))

        print_rank_0('  > paths:' + ''.join(' ' + p for p in datapaths))

        samples = []
        for path in datapaths:
            samples.extend(process_single_datapath(
                path, tokenizer, max_seq_length, self.dataset_name))
        self.samples = samples

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def _read_hoc(file_path,dataset_name):
    fp = str(file_path)

    filenames = glob.glob(os.path.join(fp, '*.tsv'))


    data_x = []
    data_y = []
    abstract_ids = []
    for filename in filenames:
        fn_str = str(filename)
        if dataset_name == fn_str.split('/')[-1].split('.')[0]:
            with open(filename, 'r') as f:
                rowCounter = 0
                for row in f:
                    if rowCounter != 0:
                        labelSentenceIndex = row.split('\t')
                        data_x.append(labelSentenceIndex[1])
                        labels = labelSentenceIndex[0].split(',')
                        labels = [int(x.split('_')[1]) for x in labels]
                        data_y.append(labels)
                        abstract = labelSentenceIndex[-1].split('_')[0]
                        abstract_ids.append(abstract)

                    rowCounter += 1
        else:
            continue

    return data_x, data_y

def process_single_datapath(datapath, MegatronTokenizer, max_seq_length, dataset_name):
    """Tokenize one HOC data directory into Megatron training samples.

    Args:
        datapath: directory passed to ``_read_hoc``.
        MegatronTokenizer: megatron tokenizer used for id conversion.
        max_seq_length: maximum token sequence length (pad/truncate).
        dataset_name: split name, e.g. 'train' or 'dev'.

    Returns:
        List of samples produced by ``build_sample_hoc``.
    """
    print_rank_0('   > working on {}'.format(datapath))
    start_time = time.time()
    data_x, data_y = _read_hoc(datapath, dataset_name)

    samples = []
    num_samples = 0
    for sentence, label in zip(data_x, data_y):
        # BUG FIX: str.strip returns a new string; the original discarded
        # the result, so surrounding brackets were never actually removed.
        text = str(sentence).strip("[]")
        context = clean_text(text)
        # Single-segment input: no second sentence.
        no_context = None
        ids, types, paddings = build_tokens_types_paddings_from_text(
            context, no_context, MegatronTokenizer, max_seq_length)
        samples.append(build_sample_hoc(ids, types, paddings, label, num_samples))
        num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0('    > processed {} samples'
                 ' in {:.2f} seconds'.format(num_samples, elapsed_time))
    return samples
+225 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time
from functools import partial

import torch
import numpy as np

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.blurb.hoc.finetune_utils import build_data_loader
from tasks.blurb.hoc.finetune_utils import process_batch
from megatron.utils import average_losses_across_data_parallel_group


def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies.

    Builds one dataloader per path in args.valid_data, then returns a
    callback that evaluates the model on each of them and prints a
    per-class summary (HOC has 10 hallmark classes).
    """
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        # End-of-epoch callback invoked by the finetune driver.
        print_rank_last('calculating metrics ...')
        num_classes=10
        # Per-class correct counts accumulated over all dataloaders.
        correct = np.zeros(num_classes, dtype=int)
        total = 0
        if output_predictions:
            # Dumping predictions only works without data parallelism.
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'

        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            # calculate_correct_answers returns real counts only on the
            # last pipeline stage (zeros elsewhere), so only accumulate there.
            if mpu.is_pipeline_last_stage():
            #if is_last_rank():
                for i in range(num_classes):
                    correct[i] += correct_ans[i]
                total += total_count
        if is_last_rank():
            # NOTE(review): if every dataloader were empty, total would be 0
            # and this division would raise — confirm inputs are non-empty.
            for i in range(num_classes):
                percent = float(correct[i]) * 100.0 / float(total)
                print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                    '{:.4f} %'.format(epoch, correct[i], total, percent))

        if output_predictions and is_last_rank():
            # Persist raw predictions next to the loaded checkpoint.
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true.

    Runs the model in eval mode over ``dataloader``, counting per-class
    correct predictions for the 10 HOC classes.  Temporarily patches
    args.micro_batch_size / args.global_batch_size so the final (possibly
    partial) batch is handled, restoring them afterwards.  Real counts are
    returned only on the last pipeline stage; other stages return zeros.
    """
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    # Save batch-size settings; they are patched per-batch below.
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute that means
        # each "sample" from the dataset actually has multiple samples
        # that will collapse into the batch dimension (for example in
        # the RACE dataset that has several options), we need to
        # account for that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        # Per-class cross-entropy and per-class correct counts for one batch.
        # output_tensor: [batch, num_classes, 2]; labels: [batch, num_classes].
        loss_fcn = torch.nn.CrossEntropyLoss()
        num_classes = 10
        loss = None
        loss_dict = {}
        for i in range(num_classes):
            if loss is None:
                loss = loss_fcn(output_tensor[:, i, :], labels[:, i])
            else:
                loss += loss_fcn(output_tensor[:, i, :], labels[:, i])

            predicted = torch.argmax(output_tensor[:, i, :], dim=-1)
            corrects = (predicted == labels[:, i])
            loss_dict['correct{%d}' % i] = corrects.sum().item()

        loss_dict['total'] = labels.size(dim=0)
        # Evaluation only: the accumulated loss is not propagated.
        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
        return output_tensor, partial(loss_func, output_predictions, labels)

    num_classes = 10
    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = np.zeros(num_classes, dtype=int)
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data,
            # applying sample_multiplier if necessary.
            actual_batch_size = len(batch['label'])
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    # NOTE(review): loss_func never writes 'softmaxes'/'labels'/
                    # 'ids' keys, so this path would raise KeyError — confirm
                    # before using output_predictions=True.
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])

                total += loss_dict['total']
                # Accumulate each class's correct count (keys from loss_func).
                for i in range(num_classes):
                    correct[i] += loss_dict['correct{%d}' % i]

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce across the data-parallel group and report per-class accuracy.
    if mpu.is_pipeline_last_stage():
        correct_ans = np.zeros(num_classes, dtype=int)
        for i in range(num_classes):
            unreduced = torch.cuda.LongTensor([correct[i], total])
            torch.distributed.all_reduce(unreduced,
                                        group=mpu.get_data_parallel_group())
            # Print on screen.
            correct_ans[i] = unreduced[0].item()
            total_count = unreduced[1].item()
            percent = float(correct_ans[i]) * 100.0 / float(total_count)
            elapsed_time = time.time() - start_time
            print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                            '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                                epoch, name, correct_ans[i], total_count,
                                percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
+61 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GLUE finetuning/evaluation."""

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification_hoc import Classification_hoc
from tasks.blurb.hoc.eval_utils import accuracy_func_provider
from tasks.blurb.hoc.finetune_utils import finetune
from tasks.blurb.hoc.data import HOCDataset


def train_valid_datasets_provider():
    """Build train and validation dataset."""
    args = get_args()
    tokenizer = get_tokenizer()

    def build(split, paths):
        # One HOCDataset per split, sharing tokenizer and sequence length.
        return HOCDataset(split, paths, tokenizer, args.seq_length)

    return build('train', args.train_data), build('dev', args.valid_data)

def model_provider(pre_process=True, post_process=True):
    """Build the model.

    Returns a Classification_hoc model with 10 output classes (the 10
    HOC hallmark labels).
    """
    args = get_args()

    # BUG FIX: the original left the '{}' placeholder unfilled and printed
    # it literally; report the task name from args.
    print_rank_0('building classification model for {} ...'.format(args.task))
    model = Classification_hoc(num_classes=10, num_tokentypes=2,
                           pre_process=pre_process, post_process=post_process)

    return model

def metrics_func_provider():
    """Provide metrics callback function."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Evaluation always uses the 'dev' split of the given path.
        return HOCDataset('dev', [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)

def main():
    """Entry point: fine-tune and evaluate the HOC classification task."""
    finetune(train_valid_datasets_provider,
             model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
Loading