Commit 24369dd6 authored by Lawrence McAfee
Browse files

limit 'make_viewless_tensor()' to case of micro_batch_size == 1; added comment

parent 0c8e8cce
Loading
Loading
Loading
Loading
+12 −5
Original line number Diff line number Diff line
@@ -557,6 +557,7 @@ class ParallelTransformer(MegatronModule):
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.micro_batch_size = args.micro_batch_size

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
@@ -696,6 +697,12 @@ class ParallelTransformer(MegatronModule):
            hidden_states = self.input_tensor

        # Viewless tensor
        #   We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        if self.micro_batch_size == 1:
            hidden_states = mpu.make_viewless_tensor(
                hidden_states,
                requires_grad = True,