Commit beb3e0d3 authored by Raul Puri

Merge branch 'transformer_refactoring_from_pretrain_refactoring' into 'master'

Major refactoring, combining gpt2 and bert

See merge request ADLR/megatron-lm!8
parents ebbe40cd 73af1290
+1 −10
 # ===========
 # base images
 # ===========
-FROM nvcr.io/nvidia/pytorch:19.05-py3
+FROM nvcr.io/nvidia/pytorch:19.09-py3


 # ===============
@@ -27,12 +27,3 @@ RUN pip install --upgrade pip && \
 COPY requirements.txt /tmp/
 RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt

-# ===========
-# latest apex
-# ===========
-RUN pip uninstall -y apex && \
-    git clone https://github.com/NVIDIA/apex.git ~/apex && \
-    cd ~/apex && \
-    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .

docker/README.md

deleted 100644 → 0
+0 −1
-Note that as of now you need to have PySOL cloned to the directory here before building the container.
+0 −1
@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
 from megatron.learning_rates import AnnealingLR
 from megatron.model import GPT2Model
-from megatron.model import gpt2_get_params_for_weight_decay_optimization
 from megatron.model import DistributedDataParallel as DDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
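
The dropped import suggests the gpt2-specific weight-decay helper was folded into a shared one as part of merging the gpt2 and bert code paths. Below is a minimal sketch of what such a grouping helper typically does; the function name and the bias/LayerNorm exclusion rule are assumptions, not this commit's confirmed API.

```python
import torch

def get_params_for_weight_decay_optimization(module):
    # Hypothetical shared helper: biases and LayerNorm parameters are
    # excluded from weight decay; everything else gets the default decay.
    decay, no_decay = [], []
    for m in module.modules():
        if isinstance(m, torch.nn.LayerNorm):
            no_decay.extend(m.parameters(recurse=False))
        else:
            for name, p in m.named_parameters(recurse=False):
                (no_decay if name.endswith('bias') else decay).append(p)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# Usage with the FusedAdam import above (illustrative):
# optimizer = Adam(get_params_for_weight_decay_optimization(model),
#                  lr=1e-4, weight_decay=0.01)
```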
+13 −8
@@ -26,9 +26,8 @@ import argparse
 import time
 from arguments import get_args
 from megatron.utils import Timers
-from pretrain_gpt2 import initialize_distributed
-from pretrain_gpt2 import set_random_seed
-from pretrain_gpt2 import get_train_val_test_data
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
 from pretrain_gpt2 import get_masks_and_position_ids
 from megatron.utils import load_checkpoint
 from megatron.data_utils import make_tokenizer
@@ -96,7 +95,8 @@ def get_batch(context_tokens, args):
         tokens,
         args.eod_token,
         args.reset_position_ids,
-        args.reset_attention_mask)
+        args.reset_attention_mask,
+        False)

     return tokens, attention_mask, position_ids

@@ -361,7 +361,7 @@ def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
     return (1-boolean)*val1 + boolean*val2

-def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
+def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
@@ -384,16 +384,21 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
         while context_length <= (maxlen):

             if args.recompute:
-                logits = model(tokens, position_ids, attention_mask)
+                logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
                 logits = logits[:, context_length - 1, :]
             else:
+                types2use = None
                 if counter == 0:
                     tokens2use = tokens[:, :context_length]
                     positions2use = position_ids[:, :context_length]
+                    if type_ids is not None:
+                        types2use = type_ids[:, :context_length]
                 else:
                     tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                     positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
-                logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
+                    if type_ids is not None:
+                        types2use = type_ids[:, context_length - 1].view(batch_size, -1)
+                logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
                 logits = logits[:, -1].view(batch_size,-1).contiguous()

             if args.greedy:
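
These hunks thread an optional type_ids tensor through sampling and rename the cached-attention flag from get_present to get_key_value. Here is a self-contained sketch of the decode loop this code implements, under the assumption that model(..., get_key_value=True) returns (logits, layer_past); StubLM and greedy_decode are illustrative stand-ins, not classes from this repo.

```python
import torch

class StubLM(torch.nn.Module):
    # Hypothetical stand-in for the model interface the diff calls:
    # forward(tokens, position_ids, attention_mask,
    #         tokentype_ids=None, layer_past=None, get_key_value=False)
    def __init__(self, vocab=100, hidden=32):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, hidden)
        self.head = torch.nn.Linear(hidden, vocab)

    def forward(self, tokens, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False):
        h = self.emb(tokens)  # a real model also uses positions, types, mask
        logits = self.head(h)
        if get_key_value:
            return logits, (h,)  # a real model returns per-layer key/value caches
        return logits

def greedy_decode(model, tokens, type_ids=None, steps=8):
    # Incremental decoding in the style of sample_sequence_batch: feed the
    # whole prefix on the first step, then only the newest token, carrying
    # layer_past forward instead of recomputing attention over the prefix.
    batch_size = tokens.size(0)
    layer_past = None
    with torch.no_grad():
        for counter in range(steps):
            positions = torch.arange(tokens.size(1)).unsqueeze(0).expand(batch_size, -1)
            if counter == 0:
                tokens2use, positions2use = tokens, positions
                types2use = type_ids
            else:
                tokens2use = tokens[:, -1:].contiguous()
                positions2use = positions[:, -1:]
                types2use = type_ids[:, -1:] if type_ids is not None else None
            logits, layer_past = model(tokens2use, positions2use, None,
                                       tokentype_ids=types2use,
                                       layer_past=layer_past, get_key_value=True)
            next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_tok], dim=1)
            if type_ids is not None:
                type_ids = torch.cat([type_ids, type_ids[:, -1:]], dim=1)
    return tokens

# Usage: out = greedy_decode(StubLM(), torch.randint(0, 100, (2, 5)))
```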
+9 −1
@@ -22,6 +22,9 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm

+
+from megatron.module import MegatronModule
+
 FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)

@@ -56,7 +59,7 @@ def fp16_to_fp32(val):
         return val
     return conversion_helper(val, float_conversion)

-class FP16_Module(nn.Module):
+class FP16_Module(MegatronModule):
     def __init__(self, module):
         super(FP16_Module, self).__init__()
         self.add_module('module', module.half())
@@ -67,6 +70,11 @@ class FP16_Module(nn.Module):
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         return self.module.state_dict(destination, prefix, keep_vars)

+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(destination, prefix,
+                                                          keep_vars)
+
     def load_state_dict(self, state_dict, strict=True):
         self.module.load_state_dict(state_dict, strict=strict)

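Rebasing FP16_Module onto the new MegatronModule base class gives every wrapper a checkpoint-specific state-dict hook. A sketch of the delegation pattern follows; the base-class body shown here is an assumption (that the default hook simply falls through to the ordinary state_dict), not code from this commit.

```python
import torch

class MegatronModule(torch.nn.Module):
    # Assumed default: the checkpoint hook falls back to the regular
    # state_dict unless a subclass overrides it.
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.state_dict(destination, prefix, keep_vars)

class FP16_Module(MegatronModule):
    # Mirrors the diff above: the fp16 wrapper forwards the checkpoint
    # hook to the wrapped module so checkpoints keep the inner keys.
    def __init__(self, module):
        super().__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

class TinyModel(MegatronModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Checkpoints look the same with or without the fp16 wrapper:
model = FP16_Module(TinyModel())
print(list(model.state_dict_for_save_checkpoint().keys()))  # ['linear.weight', 'linear.bias']
```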