megatron/p2p_communication.py (+36 −26)

@@ -21,7 +21,8 @@
 from megatron import get_args
 from megatron import mpu


-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
+                 use_ring_exchange=False):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -34,6 +35,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                    previous rank.
         recv_next: boolean for whether tensor should be received from
                    next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.

     Returns:
         (tensor_recv_prev, tensor_recv_next)
@@ -73,6 +76,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
+    if use_ring_exchange:
+        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                        tensor_recv_prev=tensor_recv_prev,
+                                        tensor_send_next=tensor_send_next,
+                                        tensor_recv_next=tensor_recv_next,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
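For context, the patch dispatches between two exchange paths: a single fused torch.distributed.ring_exchange() call when requested, otherwise the existing batched point-to-point ops. Below is a minimal, self-contained sketch of that same dispatch pattern, not the patch itself; the function name exchange_with_neighbors and the prev_rank/next_rank parameters are illustrative, and since ring_exchange() is only present in builds that ship it, the sketch guards on hasattr and falls back to standard batched isend/irecv.

import torch
import torch.distributed as dist


def exchange_with_neighbors(tensor_send_prev, tensor_send_next,
                            tensor_recv_prev, tensor_recv_next,
                            prev_rank, next_rank, group,
                            use_ring_exchange=False):
    """Exchange tensors with the previous/next pipeline rank.

    Illustrative sketch of the dispatch added by the patch: either one fused
    ring_exchange call (when the build provides it) or a batched set of
    point-to-point operations. Assumes torch.distributed has already been
    initialized (dist.init_process_group).
    """
    if use_ring_exchange and hasattr(dist, "ring_exchange"):
        # Fused exchange in a single call; only available in builds that
        # expose torch.distributed.ring_exchange().
        dist.ring_exchange(tensor_send_prev=tensor_send_prev,
                           tensor_recv_prev=tensor_recv_prev,
                           tensor_send_next=tensor_send_next,
                           tensor_recv_next=tensor_recv_next,
                           group=group)
    else:
        # Fallback: build one P2POp per direction and launch them together.
        ops = []
        if tensor_send_prev is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank, group=group))
        if tensor_recv_prev is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank, group=group))
        if tensor_send_next is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank, group=group))
        if tensor_recv_next is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank, group=group))
        if ops:
            # batch_isend_irecv returns one request per op; wait on all of them.
            for req in dist.batch_isend_irecv(ops):
                req.wait()

In _communicate itself, callers would opt into the new path simply by passing use_ring_exchange=True; the default of False preserves the existing batched P2P behavior.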