Commit 5993f04b authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

removed distribute_checkpointed_activations_helper()

parent d6380fda
Loading
Loading
Loading
Loading
+2 −19
Original line number Diff line number Diff line
@@ -626,23 +626,6 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               Parallel ranks if the `distribute-checkpointed-activations
               is on and either of the following conditions is met:
                 - it is not the first layer in the in the pipeline stage.
                   The first layer is used in the pipeline parallelism 
                   and changing its shape throws error in the backward pass.
                 - we are at the first pipline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
@@ -651,7 +634,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -662,7 +645,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(