Commit 08ddde0c authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Fix deadlock when get_num_microbatches() < pipeline-parallel size (don't try...

Fix deadlock when get_num_microbatches() < pipeline-parallel size (don't try to measure pipeline stall)
parent 854c3409
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Measure pipeline stall only if there are enough microbatches
    # to have every worker in a warmup and steady state phase.
    measure_pipeline_stall = get_num_microbatches() >= \
        mpu.get_pipeline_model_parallel_world_size()

    input_tensors = []
    output_tensors = []
    losses_reduced = []
@@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # Barrier before first receive to measure forward stall.
        if i == (num_warmup_microbatches - 1):
        if i == (num_warmup_microbatches - 1) and measure_pipeline_stall:
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
@@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        output_tensors.append(output_tensor)

    # Barrier before first receive to measure forward stall.
    if num_warmup_microbatches == 0:
    if num_warmup_microbatches == 0 and measure_pipeline_stall:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()