Commit c61dc22f authored by mshoeybi

some cleanup

parent b8940b96
+4 −2
@@ -240,10 +240,12 @@ def parse_args(extra_args_provider=None, defaults={},
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
-        assert args.tensor_model_parallel_size > 1
+        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
+            'checkpointed activations only across tensor model ' \
+            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
-            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
+            'need to use an activation-checkpoint method'

    _print_args(args)
    return args
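
For context, a minimal sketch (not part of this commit) of a flag combination that satisfies both checks, assuming the usual Megatron-LM argument names implied by the attributes asserted above; the exact flag spellings are an assumption here, not shown in the diff:

    # Hypothetical flag set, names inferred from the attributes checked above.
    example_argv = [
        '--tensor-model-parallel-size', '2',           # must be > 1
        '--distribute-checkpointed-activations',       # enables the distribution path
        '--activations-checkpoint-method', 'uniform',  # must not be None
    ]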
+19 −2
@@ -608,6 +608,23 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

+        def distribute_checkpointed_activations_helper(layer_number):
+            """Distribute checkpointed activations across the tensor model
+               parallel ranks if `distribute-checkpointed-activations`
+               is on and either of the following conditions is met:
+                 - it is not the first layer in the pipeline stage.
+                   The first layer's input is used by pipeline parallelism,
+                   and changing its shape throws an error in the backward pass.
+                 - we are at the first pipeline stage, so the input tensor is
+                   not used in pipeline parallelism. Note that no pipeline
+                   parallelism is a special case of this.
+            """
+            not_first_layer_in_pipeline_stage = (layer_number > 0)
+            is_first_pipeline_stage = (
+                mpu.get_pipeline_model_parallel_rank() == 0)
+            return self.distribute_checkpointed_activations and \
+                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
+
        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
@@ -616,7 +633,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                    distribute_checkpointed_activations_helper(l),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -627,7 +644,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
-                        self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
+                        distribute_checkpointed_activations_helper(l),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
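
As a reading aid, here is a minimal standalone sketch (not part of the commit) of the decision the new helper encodes, with the pipeline rank passed in explicitly instead of being read from mpu, so the cases from the docstring can be checked in isolation:

    def should_distribute(distribute_flag, layer_number, pipeline_rank):
        # Mirrors distribute_checkpointed_activations_helper above.
        not_first_layer_in_pipeline_stage = layer_number > 0
        is_first_pipeline_stage = pipeline_rank == 0
        return distribute_flag and (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

    # Layer 0 on a later pipeline stage is left alone: its input is the tensor
    # exchanged between stages, and reshaping it breaks the backward pass.
    assert not should_distribute(True, layer_number=0, pipeline_rank=1)
    # Any later layer in the stage can be distributed.
    assert should_distribute(True, layer_number=2, pipeline_rank=1)
    # On the first pipeline stage even layer 0 can be distributed.
    assert should_distribute(True, layer_number=0, pipeline_rank=0)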