Commit dd889062 authored by Deepak Narayanan

Interleaved pipeline execution and code refactoring

- Split a model's computation into multiple virtual stages as needed,
and schedule communication correctly between these virtual stages
- Move schedule code into `schedules.py` and communication code into
`p2p_communication.py`
- Use hyphens instead of spaces in all time logging for consistency
- Factor out code in megatron/training.py into helper functions
- Refactor evaluate() function: make it use forward_backward_schedule
functions
parent c1faa9fe
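
For context on the shape change that runs through the diffs below: with interleaving, each physical pipeline rank builds several model chunks, one per virtual stage, so `model` becomes a list. A minimal sketch of that construction, assuming Megatron is importable; `build_model_chunk()` is a hypothetical stand-in for the real model provider:

import torch
from megatron import mpu

def build_model_chunk():
    # Hypothetical stand-in for the real model provider; an actual chunk
    # would be a slice of ParallelTransformer layers for this virtual stage.
    return torch.nn.Linear(8, 8)

def build_interleaved_model(virtual_pipeline_model_parallel_size):
    model = []
    for i in range(virtual_pipeline_model_parallel_size):
        # Record which virtual stage is being built so that
        # ParallelTransformer computes the right layer offset
        # (see the transformer.py diff below).
        mpu.set_virtual_pipeline_model_parallel_rank(i)
        model.append(build_model_chunk())
    return model  # one chunk per virtual stage; save/load iterate over this list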
megatron/arguments.py +2 −0
@@ -557,6 +557,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
+                       help='Number of virtual pipeline stages in physical stage.')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
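
The new flag multiplies the existing pipeline dimension: the model is cut into pipeline_size * virtual_size chunks, with each physical stage hosting virtual_size of them. A quick back-of-the-envelope with illustrative numbers:

# Illustrative numbers only: 16 layers, 4 physical stages, 2 virtual stages.
num_layers = 16
pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 2

chunks = pipeline_model_parallel_size * virtual_pipeline_model_parallel_size
layers_per_chunk = num_layers // chunks
print(chunks, layers_per_chunk)  # 8 chunks, 2 layers per chunk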
megatron/checkpointing.py +25 −6
@@ -111,8 +111,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    unwrapped_model = []
+    for model_module in model:
+        if isinstance(model_module, torchDDP):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    model = unwrapped_model

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -124,7 +128,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:
@@ -211,8 +220,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    unwrapped_model = []
+    for model_module in model:
+        if isinstance(model_module, torchDDP):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    model = unwrapped_model

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -297,7 +311,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering
     if get_checkpoint_version() < 2.0:
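
The net effect on the checkpoint format, sketched with placeholder values: a single chunk keeps the old 'model' key, so existing checkpoints still load, while multiple chunks get numbered keys from the 'model%d' format above:

# Sketch of the saved layout (placeholder values) when this rank holds
# two model chunks; keys follow the 'model%d' format used above.
state_dict = {
    'checkpoint_version': 3.0,
    'iteration': 1000,
    'model0': {'...': '...'},  # state of virtual stage 0's chunk
    'model1': {'...': '...'},  # state of virtual stage 1's chunk
}
# With len(model) == 1 the unnumbered key is used instead:
state_dict_single = {'checkpoint_version': 3.0, 'iteration': 1000, 'model': {'...': '...'}}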
megatron/initialize.py +2 −1
@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


def _init_autoresume():
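
mpu.initialize_model_parallel now also receives the virtual size. The setter/getter calls used elsewhere in this commit suggest mpu tracks the virtual dimension in module-level state, roughly like the sketch below; the actual code lives in megatron/mpu and is not part of this view, so everything beyond the two function names is an assumption:

# Rough sketch of the state mpu would need for the new dimension; the
# function names mirror the set/get calls in the diffs, the rest is assumed.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

def set_virtual_pipeline_model_parallel_rank(rank):
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank

def get_virtual_pipeline_model_parallel_rank():
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK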
megatron/module.py +2 −2
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):


     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
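
Why ignore_virtual: with interleaving, the rank that is physically first in the pipeline also hosts later virtual stages, and embedding-weight sharing must key off the physical position, not the current virtual stage. A self-contained sketch of the assumed semantics; the real checks live in mpu and read global state rather than taking arguments:

# Sketch of the ignore_virtual semantics assumed by the diff above:
# without it, a rank only counts as "first stage" when it is also on
# virtual stage 0; word-embedding sharing cares about the physical
# first/last rank regardless of the current virtual stage.
def is_pipeline_first_stage(pipeline_rank, virtual_rank=None,
                            ignore_virtual=False):
    if not ignore_virtual and virtual_rank is not None:
        if virtual_rank != 0:
            return False
    return pipeline_rank == 0

assert is_pipeline_first_stage(0, virtual_rank=1) is False
assert is_pipeline_first_stage(0, virtual_rank=1, ignore_virtual=True)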
megatron/model/transformer.py +9 −1
@@ -552,6 +552,14 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
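
Worked numbers for the offset arithmetic above (illustrative sizes; self.num_layers starts as layers per physical stage and is then divided by the virtual size):

# Illustrative: 16 layers, 4 physical stages, 2 virtual stages per stage.
num_layers = 16
pipeline_size = 4
virtual_size = 2
layers_per_chunk = num_layers // pipeline_size // virtual_size  # 2

for pp_rank in range(pipeline_size):
    for vp_rank in range(virtual_size):
        # Mirrors the diff: virtual-rank term plus pipeline-rank term.
        offset = vp_rank * (num_layers // virtual_size) + pp_rank * layers_per_chunk
        layers = list(range(offset + 1, offset + 1 + layers_per_chunk))
        print(f'pp_rank={pp_rank} vp_rank={vp_rank}: layers {layers}')
# pp_rank=0 builds layers [1, 2] and [9, 10]: chunks are interleaved
# across the pipeline rather than contiguous.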