Commit b2f57fc4 authored by Mostofa Patwary

pulled latest megatron

parents a4b628ab 76e3fca0
megatron/arguments.py  +37 −4
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when ' \
+            'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
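
With the interleaved (virtual pipeline) schedule, the layers owned by each physical pipeline stage are split into several model chunks, and the arithmetic above determines how many chunks each stage holds. A minimal sketch of the same computation with illustrative values (not taken from this commit):

num_layers = 24
pipeline_model_parallel_size = 4                  # physical pipeline stages
num_layers_per_virtual_pipeline_stage = 3

layers_per_stage = num_layers // pipeline_model_parallel_size         # 6
virtual_pipeline_model_parallel_size = \
    layers_per_stage // num_layers_per_virtual_pipeline_stage
print(virtual_pipeline_model_parallel_size)       # 2 model chunks per stage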
@@ -203,6 +215,22 @@ def parse_args(extra_args_provider=None, defaults={},
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
    # Load scaled_masked_softmax_fusion_kernels
    if args.masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
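
The predicate can be checked in isolation. A standalone sketch with hypothetical argument values (names mirror the args fields above):

# Hypothetical values for illustration.
seq_length = 1024
num_attention_heads = 16
tensor_model_parallel_size = 2
micro_batch_size = 4

# Heads are split across tensor-parallel ranks, so each rank sees
# (heads / tp) * micro_batch attention batches.
attn_batch_size = (num_attention_heads / tensor_model_parallel_size) * micro_batch_size

custom_kernel_constraint = (16 < seq_length <= 2048 and
                            seq_length % 4 == 0 and
                            attn_batch_size % 4 == 0)
print(custom_kernel_constraint)  # True -> fused softmax kernel is eligible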
@@ -478,9 +506,9 @@ def _add_checkpointing_args(parser):
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
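
Giving a store_true flag default=None lets downstream code distinguish 'never specified' from 'explicitly enabled', which matters when argument values can also be filled in from a saved checkpoint. A minimal argparse sketch of the behavior:

import argparse

parser = argparse.ArgumentParser()
# default=None overrides store_true's implicit default of False.
parser.add_argument('--no-save-optim', action='store_true', default=None)

assert parser.parse_args([]).no_save_optim is None             # flag absent
assert parser.parse_args(['--no-save-optim']).no_save_optim    # flag given -> True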
@@ -541,6 +569,8 @@ def _add_distributed_args(parser):
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
@@ -548,6 +578,9 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
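
The scatter/gather flag pairs a negative option name with action='store_false' and a positive dest, so the optimization is on by default and the flag turns it off. A small argparse sketch of that inversion (the flag name is copied from above, everything else is illustrative):

import argparse

parser = argparse.ArgumentParser()
# store_false implies a default of True on the positive-named dest.
parser.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                    dest='scatter_gather_tensors_in_pipeline')

assert parser.parse_args([]).scatter_gather_tensors_in_pipeline is True
assert parser.parse_args(
    ['--no-scatter-gather-tensors-in-pipeline']).scatter_gather_tensors_in_pipeline is False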
megatron/checkpointing.py  +23 −12
@@ -21,12 +21,12 @@ import sys
import numpy as np

import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP

from megatron import (get_args,
                      mpu,
                      print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

_CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
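
utils.unwrap_model itself is not shown in this diff. A plausible sketch of such a helper, inferred from the call sites here (it must peel wrapper modules like DistributedDataParallel and pass through the list of model chunks, since len(model) is used below):

from torch.nn.parallel import DistributedDataParallel as torchDDP

def unwrap_model(model, module_instances=(torchDDP,)):
    # Accept a single module or the list of model chunks used by
    # the interleaved pipeline schedule.
    is_list = isinstance(model, list)
    modules = model if is_list else [model]
    unwrapped = []
    for m in modules:
        while isinstance(m, module_instances):  # peel nested wrappers
            m = m.module
        unwrapped.append(m)
    return unwrapped if is_list else unwrapped[0]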
@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
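
With interleaving a rank owns several model chunks, so the checkpoint stores one entry per chunk under 'model0', 'model1', ..., and load_checkpoint below reads them back with the same branching. A round-trip sketch of the keying scheme, using plain dicts in place of real state dicts:

chunks = [{'w': 1}, {'w': 2}]                    # model chunks on this rank

state_dict = {}
if len(chunks) == 1:
    state_dict['model'] = chunks[0]
else:
    for i, chunk in enumerate(chunks):
        state_dict['model%d' % i] = chunk        # one key per virtual stage

# Loading mirrors the branch on the number of chunks.
if len(chunks) == 1:
    restored = [state_dict['model']]
else:
    restored = [state_dict['model%d' % i] for i in range(len(chunks))]
assert restored == chunks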
@@ -238,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    args = get_args()
    load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -324,7 +328,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
@@ -352,12 +361,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

@@ -376,8 +388,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f

    args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

    load_path = args.load if from_realm_chkpt else args.ict_load

megatron/initialize.py  +2 −1
@@ -133,7 +133,8 @@ def _initialize_distributed():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


def _init_autoresume():
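
How the three sizes relate for a hypothetical 16-GPU job (values are illustrative; the virtual size is passed positionally as the third argument, matching the call site above):

from megatron import mpu

world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 2         # None disables interleaving

# Remaining GPUs form the data-parallel dimension: 16 // (2 * 4) = 2.
data_parallel_size = world_size // (tensor_model_parallel_size *
                                    pipeline_model_parallel_size)

mpu.initialize_model_parallel(tensor_model_parallel_size,
                              pipeline_model_parallel_size,
                              virtual_pipeline_model_parallel_size)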
megatron/model/fused_softmax.py  +10 −5
@@ -116,15 +116,20 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
 
    def forward(self, input, mask):
        # [b, np, sq, sk]
-        assert input.dim() == 4
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
+        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]

-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
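
The eligibility check now also involves the flattened batch*heads dimension of the [b, np, sq, sk] scores tensor, since the warp-level kernel parallelizes across it. A small sketch of the shape arithmetic on a dummy tensor (sizes illustrative):

import torch

scores = torch.empty(2, 8, 128, 128)               # [b, np, sq, sk]

query_seq_len = scores.size(-2)
key_seq_len = scores.size(-1)
attn_batch_size = scores.size(0) * scores.size(1)  # b * np = 16

custom_kernel_constraint = (16 < key_seq_len <= 2048 and
                            query_seq_len % 4 == 0 and
                            attn_batch_size % 4 == 0)
print(custom_kernel_constraint)  # True -> fused kernel path can be taken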
megatron/model/module.py  +2 −2
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):


    def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last '
                                'stage, but share_word_embeddings is false')
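
Under interleaving, a rank on the physically first pipeline stage also executes chunks belonging to later virtual stages, so the plain first/last-stage checks can report False while a later chunk is active; ignore_virtual=True restores the purely physical test. A toy sketch of the assumed semantics (not the actual mpu implementation):

def is_pipeline_first_stage(stage_rank, virtual_rank, ignore_virtual=False):
    physically_first = (stage_rank == 0)
    if ignore_virtual:
        return physically_first                    # physical placement only
    return physically_first and virtual_rank == 0  # current chunk must be first too

# Physical stage 0 while running its second model chunk:
assert is_pipeline_first_stage(0, 1) is False
assert is_pipeline_first_stage(0, 1, ignore_virtual=True) is True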