Commit 90623142 authored by Jared Casper

Merge branch 'interleaved_schedule_assertion_fix' into 'main'

Fix interleaved schedule assertion

See merge request ADLR/megatron-lm!237
parents 9dc111cc b7067cbd
megatron/arguments.py  +0 −3
@@ -123,9 +123,6 @@ def parse_args(extra_args_provider=None, defaults={},
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // args.pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
-        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
-            'global batch size is not divisible by pipeline parallel size when ' \
-            'using interleaved schedule'
     else:
         args.virtual_pipeline_model_parallel_size = None
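
(Note: the removed assertion checked the wrong quantity. The interleaved schedule constrains the number of microbatches each iteration is split into, not the global batch size itself, and divisibility of one does not imply divisibility of the other. The check is moved into train_step below, where the microbatch count is known; see the sketch after the second hunk.)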

megatron/training.py  +3 −0
@@ -339,6 +339,9 @@ def train_step(forward_step_func, data_iterator,
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
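
To see why the assertion belongs on the microbatch count rather than the global batch size, here is a minimal sketch, assuming Megatron-LM's usual relation num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size). The helper names below are hypothetical and for illustration only; the second example shows a configuration the removed check would have accepted but the new one correctly rejects.

# Minimal sketch (not repository code) of the constraint the moved
# assertion enforces. Helper names are hypothetical.

def num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    # Each data-parallel replica consumes global_batch_size / data_parallel_size
    # samples per iteration, split into chunks of micro_batch_size.
    return global_batch_size // (micro_batch_size * data_parallel_size)

def check_interleaved(global_batch_size, micro_batch_size,
                      data_parallel_size, pipeline_parallel_size):
    n = num_microbatches(global_batch_size, micro_batch_size, data_parallel_size)
    # The interleaved schedule hands out microbatches in groups of
    # pipeline_parallel_size, so the count must divide evenly.
    return n % pipeline_parallel_size == 0

# Old check passes (512 % 4 == 0) and the schedule is indeed valid:
# 512 // (4 * 16) = 8 microbatches, and 8 % 4 == 0.
assert check_interleaved(512, micro_batch_size=4,
                         data_parallel_size=16, pipeline_parallel_size=4)

# Old check also passes (64 % 4 == 0), yet the schedule is invalid:
# 64 // (4 * 8) = 2 microbatches, and 2 % 4 != 0. This is exactly
# the case the relocated assertion now catches.
assert not check_interleaved(64, micro_batch_size=4,
                             data_parallel_size=8, pipeline_parallel_size=4)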