megatron/p2p_communication.py +44 −12

@@ -22,7 +22,9 @@ from megatron import mpu


 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             next rank.
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
+        tensor_shape: optional, used when the input sequence contains fewer
+                      tokens than the default sequence length.
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                      tensor_shape is provided to override the scatter-gather
+                      optimization.
+        dtype_: optional, the dtype of the tensors to communicate when
+                tensor_shape is provided.

     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
             mpu.get_tensor_model_parallel_world_size()
     else:
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
     if recv_prev:
         tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
         tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     torch.cuda.synchronize()

     # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None,
+                 timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
     if timers is not None:
         timers('forward-send').stop()
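The dtype_ path in _communicate doubles as an inference-style escape hatch: when a caller supplies a dtype, the receive buffer is allocated with that dtype and with requires_grad=False, keeping the transfer out of autograd. Below is a minimal sketch of just that allocation logic, with assumed stand-in values (torch.half for args.params_dtype, a hypothetical int64 payload and shape); the real code additionally passes device=torch.cuda.current_device().

import torch

# Assumed stand-ins for args.params_dtype and the caller-supplied dtype_.
params_dtype = torch.half
dtype_ = torch.int64           # hypothetical, e.g. a stage exchanging token ids
tensor_chunk_shape = (50, 4)   # hypothetical chunk shape

dtype = params_dtype
requires_grad = True
if dtype_ is not None:
    # A caller-supplied dtype wins, and the buffer is detached from autograd.
    dtype = dtype_
    requires_grad = False

# Mirrors the torch.empty(...) calls in _communicate (device argument
# omitted so the sketch runs on CPU).
tensor_recv_prev = torch.empty(tensor_chunk_shape,
                               requires_grad=requires_grad,
                               dtype=dtype)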
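Taken together, the new keyword arguments let a caller, such as an inference loop, exchange activations whose sequence length differs from args.seq_length. A minimal usage sketch follows, assuming Megatron's distributed and pipeline state is already initialized, with made-up sizes and a stand-in for the stage's layers; note that only recv_forward needs tensor_shape, since send_forward can read the shape off the tensor it is given.

import torch
from megatron import p2p_communication as p2p

# Hypothetical sizes for a short-sequence inference step.
seq_len, micro_batch_size, hidden_size = 50, 4, 1024
model = torch.nn.Identity()  # stand-in for this pipeline stage's layers

# Receive with an explicit shape and dtype; the override keeps the buffer
# whole instead of splitting it across tensor-parallel ranks.
input_tensor = p2p.recv_forward(
    tensor_shape=(seq_len, micro_batch_size, hidden_size),
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.float32)

output_tensor = model(input_tensor)

# send_forward needs no tensor_shape: it infers the layout from
# output_tensor, but both ends must agree on the override and dtype.
p2p.send_forward(
    output_tensor,
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.float32)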