megatron/p2p_communication.py  (+38 −56)

@@ -21,8 +21,7 @@
 from megatron import get_args
 from megatron import mpu


-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.

@@ -35,8 +34,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    previous rank.
         recv_next: boolean for whether tensor should be received from
                    next rank.
-        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
-                           API should be used.

     Returns:
         (tensor_recv_prev, tensor_recv_next)

@@ -76,13 +73,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
-    if use_ring_exchange:
-        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                        tensor_recv_prev=tensor_recv_prev,
-                                        tensor_send_next=tensor_send_next,
-                                        tensor_recv_next=tensor_recv_next,
-                                        group=mpu.get_pipeline_model_parallel_group())
-    else:
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(

@@ -104,6 +94,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             torch.distributed.irecv, tensor_recv_next,
             mpu.get_pipeline_model_parallel_next_rank())
         ops.append(recv_next_op)
+    if len(ops) > 0:
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()

@@ -123,7 +114,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None, use_ring_exchange=False):
+def recv_forward(timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None

@@ -134,14 +125,13 @@ def recv_forward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor


-def recv_backward(timers=None, use_ring_exchange=False):
+def recv_backward(timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None

@@ -152,14 +142,13 @@ def recv_backward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad


-def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward(output_tensor, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:

@@ -168,13 +157,12 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward(input_tensor_grad, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:

@@ -183,13 +171,12 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward_recv_backward(output_tensor, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None

@@ -200,14 +187,13 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward_recv_forward(input_tensor_grad, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None

@@ -218,8 +204,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor

@@ -233,8 +218,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=recv_prev,
-            recv_next=False,
-            use_ring_exchange=True)
+            recv_next=False)
         if timers is not None:
             timers('forward-send-forward-recv').stop()
     return input_tensor

@@ -248,8 +232,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=recv_next,
-            use_ring_exchange=True)
+            recv_next=recv_next)
         if timers is not None:
             timers('backward-send-backward-recv').stop()
     return output_tensor_grad

@@ -265,8 +248,7 @@ def send_forward_backward_recv_forward_backward(
             tensor_send_next=output_tensor,
             tensor_send_prev=input_tensor_grad,
             recv_prev=recv_prev,
-            recv_next=recv_next,
-            use_ring_exchange=True)
+            recv_next=recv_next)
         if timers is not None:
             timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad


megatron/schedules.py  (+2 −2)

@@ -210,7 +210,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers, use_ring_exchange=True))
+        p2p_communication.recv_forward(timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)

@@ -322,7 +322,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
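After this change, _communicate() always takes the batched point-to-point path shown in the hunks above: it collects torch.distributed.P2POp descriptors into a list, hands them to torch.distributed.batch_isend_irecv() only when the list is non-empty, and waits on the returned requests. The following is a minimal, standalone sketch of that pattern, not Megatron code; the function name, rank arguments, and tensor shape are hypothetical, and it assumes torch.distributed has already been initialized.

import torch
import torch.distributed as dist


def exchange_with_neighbors(tensor_send_next, prev_rank, next_rank,
                            recv_shape, dtype, device):
    """Send a tensor to next_rank and receive one of recv_shape from prev_rank."""
    tensor_recv_prev = torch.empty(recv_shape, dtype=dtype, device=device)
    ops = []
    if tensor_send_next is not None:
        # Queue an asynchronous send to the next pipeline stage.
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    # Queue an asynchronous receive from the previous pipeline stage.
    ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if len(ops) > 0:
        # Launch all queued sends/receives together, then wait for completion.
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    return tensor_recv_prev

Posting the sends and receives as one batch lets neighboring pipeline stages issue their communication together, without depending on a non-standard ring-exchange primitive or on careful manual ordering of individual isend/irecv calls.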