Commit d16e2a24 authored by Lawrence McAfee

further clarified viewless tensor comment in transformer.py

parent 24369dd6
+20 −11
@@ -557,7 +557,6 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
-        self.micro_batch_size = args.micro_batch_size
 
         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
@@ -696,22 +695,32 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor
 
-        # Viewless tensor
-        #   We only need to create a viewless tensor in the case of micro batch
+        # Viewless tensor.
+        # - We only need to create a viewless tensor in the case of micro batch
         #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
         #   above creates a view tensor, and '.contiguous()' is a pass-through.
         #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
         #   the need to make it viewless.
-        if self.micro_batch_size == 1:
-            hidden_states = mpu.make_viewless_tensor(hidden_states,
-                                                     requires_grad = True,
-                                                     keep_graph = True)
+        #
+        #   However, we don't explicitly check mbs == 1 here because
+        #   make_viewless_tensor() has negligible overhead when its input
+        #   is already viewless.
+        #
+        # - For the 'else' case above, calling make_viewless_tensor() here is
+        #   likely redundant, since p2p_communication.py (likely originator)
+        #   already creates viewless tensors. That said, make_viewless_tensor()
+        #   is called here to be future-proof and corner-case-proof.
+        hidden_states = mpu.make_viewless_tensor(
+            hidden_states,
+            requires_grad = True,
+            keep_graph = True,
+        )

         # Transpose encoder output.
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
 
         # Forward pass.
         if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
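
Note on the view semantics the new comment relies on (an illustration, not part of the commit): for micro batch size 1, transposing hidden_states across the singleton batch dimension does not change the memory layout, so the result is still contiguous and '.contiguous()' returns the view unchanged; for mbs >= 2, the transpose breaks contiguity, '.contiguous()' copies, and the copy is born viewless. A minimal PyTorch check ('._base' is a private attribute, used here only to observe whether a tensor is a view):

    import torch

    seq, hidden = 5, 8

    # mbs == 1: transpose over the singleton batch dim keeps the tensor
    # contiguous, so .contiguous() is a pass-through and the result is
    # still a view of x1.
    x1 = torch.randn(seq, 1, hidden)
    h1 = x1.transpose(0, 1).contiguous()
    print(h1.is_contiguous(), h1._base is x1)    # True True

    # mbs >= 2: transpose breaks contiguity, so .contiguous() copies,
    # and the fresh copy is already viewless.
    x2 = torch.randn(seq, 2, hidden)
    h2 = x2.transpose(0, 1).contiguous()
    print(h2.is_contiguous(), h2._base is None)  # True True

This is also why dropping the 'if self.micro_batch_size == 1' guard is cheap: the "negligible overhead" remark in the new comment implies make_viewless_tensor() returns immediately when its input is already viewless. Below is a sketch of logic consistent with that comment, not the actual mpu implementation; '_MakeViewless' and '_rewrap' are illustrative names:

    import torch

    def make_viewless_tensor_sketch(inp, requires_grad, keep_graph):
        # Already viewless (no base tensor): return as-is. This early exit
        # is what makes the unconditional call negligible on the mbs >= 2
        # path, where .contiguous() has already produced a fresh tensor.
        if inp._base is None:
            return inp
        if keep_graph:
            # Preserve autograd history through a custom Function whose
            # backward is the identity.
            return _MakeViewless.apply(inp, requires_grad)
        return _rewrap(inp, requires_grad)

    def _rewrap(inp, requires_grad):
        # Allocate a fresh tensor object, then point its .data at the
        # input's storage: same memory, no recorded view relationship.
        out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                          requires_grad=requires_grad)
        out.data = inp.data
        return out

    class _MakeViewless(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, requires_grad):
            return _rewrap(inp, requires_grad)

        @staticmethod
        def backward(ctx, grad_output):
            # Identity gradient; None for the non-tensor argument.
            return grad_output, None

As I understand it, keeping hidden_states viewless matters here because downstream pipeline-parallel code (e.g. the p2p_communication.py path the comment mentions) reclaims activation memory by swapping out a tensor's '.data', which is only safe when no hidden '._base' is sharing that storage.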