Commit 23632ee5 authored by Mohammad Shoeybi
Browse files

Merge branch 'interleaved_bugfix' into 'main'

Small bugfix to make sure refactored code works with interleaved schedule

See merge request ADLR/megatron-lm!256
parents 3fc035d7 6fd78189
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -195,21 +195,23 @@ def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    pre_process = mpu.is_pipeline_first_stage()
    post_process = mpu.is_pipeline_last_stage()

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(
            pre_process=pre_process,
            post_process=post_process