Commit 76e3fca0 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'better_schedule' into 'main'

Interleaved pipelined execution and refactored schedule / communication code

See merge request ADLR/megatron-lm!189
parents c1faa9fe 72e4230b
Loading
Loading
Loading
Loading
+18 −1
Original line number Diff line number Diff line
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not'\
        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
            'global batch size is not divisible by pipeline parallel size when ' \
            'using interleaved schedule'
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
@@ -557,6 +569,8 @@ def _add_distributed_args(parser):
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
@@ -564,6 +578,9 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+18 −10
Original line number Diff line number Diff line
@@ -21,12 +21,12 @@ import sys
import numpy as np

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches)
                      update_num_microbatches,
                      utils)

_CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        state_dict['model'] = model.state_dict_for_save_checkpoint()
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
@@ -211,8 +215,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -297,7 +301,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'], strict=strict)
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering
    if get_checkpoint_version() < 2.0:
@@ -370,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f

    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    load_path = args.load if from_realm_chkpt else args.ict_load

+2 −1
Original line number Diff line number Diff line
@@ -133,7 +133,8 @@ def _initialize_distributed():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size)
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size)


def _init_autoresume():
+2 −2
Original line number Diff line number Diff line
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):


    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage():
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage():
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last '
                                'stage, but share_word_embeddings is false')
+21 −1
Original line number Diff line number Diff line
@@ -552,6 +552,26 @@ class ParallelTransformer(MegatronModule):
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                    args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
Loading