Commit 7eb70084 authored by Isaac

Merge remote-tracking branch 'origin/F1STAT' into main

parents 9d86ca67 0aa86dc1

finetune_hoc.lsf

0 → 100644
+69 −0
#!/bin/bash

#BSUB -nnodes 2
#BSUB -W 2:00
#BSUB -P med106
#BSUB -alloc_flags "smt4 nvme"
#BSUB -J hoc_FULL
#BSUB -o hoc_FULL.%J
#BSUB -q batch

set +x

#module load open-ce/1.4.0-py38-0
module load open-ce
conda deactivate
conda activate /gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch
module list

export OMP_NUM_THREADS=1
#export PYTHONPATH=$PYTHONPATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/megatron/fused_kernels
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/lib
#export PATH=$PATH:/gpfs/alpine/med106/world-shared/irl1/rhel8/mytorch/lib/python3.8/site-packages/torch/include

nodes=($(cat ${LSB_DJOB_HOSTFILE} | sort | uniq | grep -v login | grep -v batch))
nnodes=${#nodes[@]}
echo $nnodes

#export TRAIN_DATA=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/picodata/train.tsv
#TRAIN_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata/train.tsv"
TRAIN_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata"
#export TRAIN_DATA=picodata/train.tsv
#export VALID_DATA=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/picodata/dev.tsv
#VALID_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata/dev.tsv"
VALID_DATA="/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/hocdata"
#export VALID_DATA=picodata/dev.tsv

export VOCAB_FILE=/gpfs/alpine/world-shared/med106/g8o/pubmed_bert-vocab.txt
export CHECKPOINT_PATH=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/finetune-HOC_BIG
export PRETRAINED_CHECKPOINT=/gpfs/alpine/med106/world-shared/irl1/rhel8/fork-megatron/chkptt

jsrun --smpiargs="-disable_gpu_hooks" -n $nnodes -r 1 -g 6 -a 6 -c 42 python tasks/main.py \
       --task HOC \
       --tensor-model-parallel-size 2 \
       --pipeline-model-parallel-size 2 \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --fp16 \
       --vocab-file $VOCAB_FILE \
       --train-data $TRAIN_DATA \
       --valid-data $VALID_DATA \
       --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
       --activations-checkpoint-method uniform \
       --save-interval 10000 \
       --save $CHECKPOINT_PATH \
       --log-interval 100 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --weight-decay 1e-2 \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 5 \
       --micro-batch-size 4 \
       --lr 0.0001 \
       --lr-warmup-fraction 0.06 \
       --distributed-backend nccl
       #--DDP-impl torch \
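The hostfile parsing above is what sizes the jsrun launch: LSB_DJOB_HOSTFILE lists one entry per slot, including login and batch service nodes, so the script de-duplicates the list and filters those out before passing the count to -n. A minimal Python sketch of the same logic (hypothetical helper, not part of the commit):

import os

def count_compute_nodes(hostfile_path):
    # De-duplicate host entries and drop LSF's login/batch service nodes,
    # mirroring the bash pipeline: sort | uniq | grep -v login | grep -v batch.
    with open(hostfile_path) as f:
        hosts = {line.strip() for line in f if line.strip()}
    return len([h for h in hosts if 'login' not in h and 'batch' not in h])

print(count_compute_nodes(os.environ['LSB_DJOB_HOSTFILE']))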
+7 −3
@@ -19,6 +19,8 @@ import argparse
 import os
 
 import torch
+from mpi4py import MPI
+import subprocess
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -54,8 +56,10 @@ def parse_args(extra_args_provider=None, defaults={},
         args = parser.parse_args()
 
     # Distributed args.
-    args.rank = int(os.getenv('RANK', '0'))
-    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+    comm = MPI.COMM_WORLD
+    args.rank = comm.Get_rank()
+    args.world_size = comm.Get_size()
 
     # Tensor model parallel size.
     args.tensor_model_parallel_size = min(
         args.tensor_model_parallel_size, args.world_size)
@@ -494,7 +498,7 @@ def _add_training_args(parser):
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
     group.add_argument('--optimizer', type=str, default='adam',
-                       choices=['adam', 'sgd'],
+                       choices=['adam', 'sgd', 'lamb'],
                        help='Optimizer function')
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
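These hunks swap the launcher-exported RANK/WORLD_SIZE environment variables for direct MPI queries, which is what a jsrun launch provides, and register lamb as a third optimizer choice. A minimal sketch of the mpi4py calls the patch relies on (run one copy per process under jsrun or mpirun; not part of the commit):

from mpi4py import MPI

comm = MPI.COMM_WORLD
# Each process reports its global rank, the total process count, and its host.
print(f'rank {comm.Get_rank()} of {comm.Get_size()} on {MPI.Get_processor_name()}')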
+9 −5
@@ -442,11 +442,15 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     valid_datasets = []
     test_datasets = []
     for i in range(len(prefixes)):
-        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
-            prefixes[i], data_impl, splits_string,
-            datasets_train_valid_test_num_samples[i],
-            max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, binary_head, dataset_type=dataset_type)
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(prefixes[i],
+                                                                       data_impl, splits_string,
+                                                                       datasets_train_valid_test_num_samples[i],
+                                                                       max_seq_length, masked_lm_prob,
+                                                                       short_seq_prob, seed,
+                                                                       skip_warmup,
+                                                                       binary_head,
+                                                                       max_seq_length_dec,
+                                                                       dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
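The reflowed call also threads one new positional argument, max_seq_length_dec, between binary_head and dataset_type, so the callee must declare the parameter in that same position. A hedged sketch of the assumed signature (this fork's actual definition may differ):

def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head, max_seq_length_dec,
                                     dataset_type='standard_bert'):
    ...  # builds one (train, valid, test) dataset triple per data prefix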
+23 −2
@@ -32,6 +32,9 @@ from megatron.global_vars import set_global_variables
 from megatron.mpu import (set_tensor_model_parallel_rank,
                           set_tensor_model_parallel_world_size)
 
+import os
+from mpi4py import MPI
+import subprocess
 
 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -177,10 +180,28 @@ def _initialize_distributed():
                 args.local_rank = device
             torch.cuda.set_device(device)
     # Call the init process
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    world_size = comm.Get_size()
+    master_addr = None
+    if rank == 0:
+        hostname_cmd = ["hostname -I"]
+        result = subprocess.check_output(hostname_cmd, shell=True)
+        master_addr = result.decode('utf-8').split()[0]
+    master_addr = comm.bcast(master_addr, root=0)
+    proc_name = MPI.Get_processor_name()
+    all_procs = comm.allgather(proc_name)
+    local_rank = sum([i == proc_name for i in all_procs[:rank]])
+    os.environ['RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['LOCAL_RANK'] = str(local_rank)
+    os.environ['MASTER_ADDR'] = master_addr
+    os.environ['MASTER_PORT'] = str(29500)
+    init_method = None
     torch.distributed.init_process_group(
         backend=args.distributed_backend,
         world_size=args.world_size, rank=args.rank,
-        timeout=timedelta(minutes=10))
+        timeout=timedelta(minutes=10),
+        init_method=init_method)
 
     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
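This hunk bootstraps torch.distributed from MPI: rank 0 broadcasts its address, each process counts how many earlier ranks share its hostname to derive a LOCAL_RANK that maps it to a distinct GPU, and with init_method=None plus the exported MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE variables, PyTorch falls back to its env:// rendezvous. A minimal sketch of the equivalent explicit form (illustration only, not part of the commit):

import os
import torch.distributed as dist

# Same rendezvous the patch sets up implicitly via environment variables.
dist.init_process_group(
    backend='nccl',
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    rank=int(os.environ['RANK']),
    world_size=int(os.environ['WORLD_SIZE']),
)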
+139 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classification model."""

import torch

from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


class Classification_hoc(MegatronModule):

    def __init__(self,
                 num_classes,
                 num_tokentypes=2,
                 pre_process=True,
                 post_process=True):
        super(Classification_hoc, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        # Multi-choice head.
        if self.post_process:
            self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
            self.classification_head0 = get_linear_layer(args.hidden_size,
                                                         args.hidden_size,
                                                         init_method)
            self.classification_head1 = get_linear_layer(args.hidden_size,
                                                         2 * self.num_classes,
                                                         init_method)
            self._classification_head_key0 = 'classification_head0'
            self._classification_head_key1 = 'classification_head1'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process:
            _, pooled_output = lm_output
            x = self.classification_dropout(pooled_output)
            x = self.classification_head0(x)
            x = torch.nn.ReLU()(x)
            x = self.classification_head1(x)

            # Reshape to (batch, num_classes, 2): one independent
            # binary decision per HoC class.
            x = torch.reshape(x, (x.shape[0], self.num_classes, 2))

            return x
        return lm_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._classification_head_key0] \
                = self.classification_head0.state_dict(
                    destination, prefix, keep_vars)
            state_dict_[self._classification_head_key1] \
                = self.classification_head1.state_dict(
                    destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            if self._classification_head_key0 in state_dict:
                self.classification_head0.load_state_dict(
                    state_dict[self._classification_head_key0], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key0))

            if self._classification_head_key1 in state_dict:
                self.classification_head1.load_state_dict(
                    state_dict[self._classification_head_key1], strict=strict)
            else:
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._classification_head_key1))
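The head ends in a (batch, num_classes, 2) tensor: HoC is a multi-label task, so each of the num_classes labels gets its own two-way softmax rather than one softmax across all labels. A hedged sketch of a loss and prediction step matching that shape (not taken from the commit):

import torch
import torch.nn.functional as F

batch, num_classes = 4, 10
logits = torch.randn(batch, num_classes, 2)         # what forward() returns
labels = torch.randint(0, 2, (batch, num_classes))  # multi-hot HoC targets

# One binary cross-entropy per (example, label) pair.
loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
preds = logits.argmax(dim=-1)                       # (batch, num_classes) in {0, 1}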