Commit 8e922d5b authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Put in barriers in appropriate places to measure length of pipeline stall

parent dd889062
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -95,6 +95,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
@@ -108,8 +109,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size -
                 mpu.get_pipeline_model_parallel_rank() - 1) * 2
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = \
@@ -272,6 +272,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers, forward_only):
    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

@@ -295,11 +297,22 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
        input_tensor = recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # Barrier before first receive to measure forward stall.
        if i == (num_warmup_microbatches - 1):
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
        send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Barrier before first receive to measure forward stall.
    if num_warmup_microbatches == 0:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
+7 −0
Original line number Diff line number Diff line
@@ -354,6 +354,11 @@ def train_step(forward_step_func, data_iterator,
                                          fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
@@ -443,10 +448,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-pipeline-stall')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-pipeline-stall')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')