Commit 1acac4e3 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'remove_stall_timing' into 'main'

Remove pipeline stall timing

See merge request ADLR/megatron-lm!244
parents 3cbf7547 ab09d819
Loading
Loading
Loading
Loading
+0 −16
Original line number Diff line number Diff line
@@ -362,11 +362,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Measure pipeline stall only if there are enough microbatches
    # to have every worker in a warmup and steady state phase.
    measure_pipeline_stall = get_num_microbatches() >= \
        mpu.get_pipeline_model_parallel_world_size()

    input_tensors = []
    output_tensors = []
    losses_reduced = []
@@ -376,22 +371,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # Barrier before first receive to measure forward stall.
        if i == (num_warmup_microbatches - 1) and measure_pipeline_stall:
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Barrier before first receive to measure forward stall.
    if num_warmup_microbatches == 0 and measure_pipeline_stall:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
+0 −7
Original line number Diff line number Diff line
@@ -356,11 +356,6 @@ def train_step(forward_step_func, data_iterator,
                                          fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
@@ -451,12 +446,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-pipeline-stall')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-pipeline-stall')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')