megatron/schedules.py  (+29 −7)

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
 import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_args
 from megatron import get_num_microbatches

@@ -74,6 +76,14 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     return input_tensor_grad
 
 
+@contextmanager
+def dummy_handler():
+    try:
+        yield
+    finally:
+        pass
+
+
 def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                    optimizer, timers, forward_only):
     """Run forward and backward passes with no pipeline parallelism

@@ -83,15 +93,27 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     assert len(model) == 1
     model = model[0]
 
+    context_handler = dummy_handler
+    if isinstance(model, torchDDP):
+        context_handler = model.no_sync
+
     losses_reduced = []
-    for i in range(get_num_microbatches()):
-        input_tensor, output_tensor_grad = None, None
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
-        if not forward_only:
-            backward_step(optimizer, input_tensor, output_tensor,
-                          output_tensor_grad)
+    input_tensor, output_tensor_grad = None, None
+    with context_handler():
+        for i in range(get_num_microbatches() - 1):
+            output_tensor = forward_step(forward_step_func, data_iterator, model,
+                                         input_tensor, losses_reduced)
+            if not forward_only:
+                backward_step(optimizer, input_tensor, output_tensor,
+                              output_tensor_grad)
+
+    # Run computation for last microbatch out of context handler (want to
+    # synchronize gradients).
+    output_tensor = forward_step(forward_step_func, data_iterator, model,
+                                 input_tensor, losses_reduced)
+    if not forward_only:
+        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
 
     return losses_reduced
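The schedules.py change runs all but the last microbatch inside a context handler: `model.no_sync()` when the model is wrapped in torch DDP (suppressing the gradient all-reduce while gradients accumulate locally), and the no-op `dummy_handler` otherwise. The last microbatch runs outside the handler so its backward pass triggers the synchronization. Below is a minimal, self-contained sketch of that pattern; the model, loss function, and microbatch data are invented for illustration and are not part of the PR:

    from contextlib import contextmanager

    import torch
    from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP


    @contextmanager
    def dummy_handler():
        # No-op stand-in so the same loop works for non-DDP models.
        yield


    def accumulate_microbatches(model, microbatches, loss_fn):
        # Skip gradient synchronization for all but the last microbatch.
        context_handler = dummy_handler
        if isinstance(model, torchDDP):
            context_handler = model.no_sync

        with context_handler():
            for inputs, targets in microbatches[:-1]:
                loss_fn(model(inputs), targets).backward()

        # Last microbatch outside the context: its backward pass
        # triggers DDP's gradient all-reduce.
        inputs, targets = microbatches[-1]
        loss_fn(model(inputs), targets).backward()


    if __name__ == "__main__":
        # Plain nn.Module (no process group), so the dummy handler is used
        # and this runs as-is in a single process.
        model = torch.nn.Linear(4, 2)
        data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(4)]
        accumulate_microbatches(model, data, torch.nn.functional.mse_loss)

Deferring the all-reduce to the final backward pass is what makes torch DDP usable with multiple micro-batches, which is why the training.py change below can drop the local-DDP-only assertion.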
megatron/training.py  (+0 −2)

@@ -309,8 +309,6 @@ def setup_model_and_optimizer(model_provider_func):
     args.iteration = 0
 
     # We only support local DDP with multiple micro-batches.
-    if get_num_microbatches() > 1:
-        assert args.DDP_impl == 'local'
     if len(model) > 1:
         assert args.DDP_impl == 'local'
     if mpu.get_pipeline_model_parallel_world_size() > 1: