Loading megatron/schedules.py +7 −2 Original line number Diff line number Diff line Loading @@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite num_microbatches_remaining = \ num_microbatches - num_warmup_microbatches # Measure pipeline stall only if there are enough microbatches # to have every worker in a warmup and steady state phase. measure_pipeline_stall = get_num_microbatches() >= \ mpu.get_pipeline_model_parallel_world_size() input_tensors = [] output_tensors = [] losses_reduced = [] Loading @@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) # Barrier before first receive to measure forward stall. if i == (num_warmup_microbatches - 1): if i == (num_warmup_microbatches - 1) and measure_pipeline_stall: timers('forward-pipeline-stall').start() torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) timers('forward-pipeline-stall').stop() Loading @@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite output_tensors.append(output_tensor) # Barrier before first receive to measure forward stall. if num_warmup_microbatches == 0: if num_warmup_microbatches == 0 and measure_pipeline_stall: timers('forward-pipeline-stall').start() torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) timers('forward-pipeline-stall').stop() Loading Loading
megatron/schedules.py +7 −2 Original line number Diff line number Diff line Loading @@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite num_microbatches_remaining = \ num_microbatches - num_warmup_microbatches # Measure pipeline stall only if there are enough microbatches # to have every worker in a warmup and steady state phase. measure_pipeline_stall = get_num_microbatches() >= \ mpu.get_pipeline_model_parallel_world_size() input_tensors = [] output_tensors = [] losses_reduced = [] Loading @@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) # Barrier before first receive to measure forward stall. if i == (num_warmup_microbatches - 1): if i == (num_warmup_microbatches - 1) and measure_pipeline_stall: timers('forward-pipeline-stall').start() torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) timers('forward-pipeline-stall').stop() Loading @@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite output_tensors.append(output_tensor) # Barrier before first receive to measure forward stall. if num_warmup_microbatches == 0: if num_warmup_microbatches == 0 and measure_pipeline_stall: timers('forward-pipeline-stall').start() torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) timers('forward-pipeline-stall').stop() Loading