megatron/model/transformer.py  +20 −11

@@ -557,7 +557,6 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
-        self.micro_batch_size = args.micro_batch_size

         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method

@@ -696,22 +695,32 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor

-        # Viewless tensor
-        # We only need to create a viewless tensor in the case of micro batch
+        # Viewless tensor.
+        # - We only need to create a viewless tensor in the case of micro batch
         #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
         #   above creates a view tensor, and '.contiguous()' is a pass-through.
         #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
         #   the need to make it viewless.
-        if self.micro_batch_size == 1:
+        #
+        #   However, we don't explicitly check mbs == 1 here because
+        #   make_viewless_tensor() has negligible overhead when its input
+        #   is already viewless.
+        #
+        # - For the 'else' case above, calling make_viewless_tensor() here is
+        #   likely redundant, since p2p_communication.py (likely originator)
+        #   already creates viewless tensors. That said, make_viewless_tensor()
+        #   is called here to be future-proof and corner-case-proof.
+        hidden_states = mpu.make_viewless_tensor(
+            hidden_states,
+            requires_grad = True,
+            keep_graph = True,
+        )

         # Transpose encoder output.
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

         # Forward pass.
         if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
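To see why the comment in the diff holds, here is a small standalone PyTorch check (not part of the change; the shapes and variable names are illustrative). With micro batch size 1, the transposed tensor is already contiguous, so .contiguous() is a pass-through and the result is still a view; with micro batch size >= 2, .contiguous() allocates a fresh, viewless tensor.

    import torch

    s, h = 4, 8

    # mbs == 1: transpose() makes a view, and contiguous() is a pass-through
    # because a [s, 1, h] view of a contiguous [1, s, h] tensor is already
    # contiguous (size-1 dims are ignored by the contiguity check).
    x = torch.randn(1, s, h)
    y = x.transpose(0, 1).contiguous()
    print(y._base is x)        # True  -> 'y' is still a view of 'x'

    # mbs >= 2: the transposed tensor is non-contiguous, so contiguous()
    # allocates a new tensor that owns its storage.
    x = torch.randn(2, s, h)
    y = x.transpose(0, 1).contiguous()
    print(y._base is None)     # True  -> 'y' is already viewless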
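For context, a simplified sketch of what a "make viewless" helper can do: return the input unchanged when it already owns its storage, otherwise hand back a tensor that shares the input's storage but is no longer treated as a view of it. The function name below is hypothetical and this is not the actual mpu.make_viewless_tensor() implementation (which, among other things, also takes the keep_graph flag seen in the diff); it only illustrates why calling it on an already-viewless input is near-free.

    import torch

    def make_viewless_sketch(inp, requires_grad=True):
        """Sketch only: return a tensor sharing inp's storage that is not a
        view (its ._base is None), so later code can swap or free its .data
        without touching a parent tensor."""
        if inp._base is None:
            # Already viewless -- returning as-is is why the unconditional
            # call in the diff has negligible overhead for mbs >= 2.
            return inp
        out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                          requires_grad=requires_grad)
        out.data = inp.data    # rebind storage; drops the view relationship
        return out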