Commit b8940b96 authored by mshoeybi

Added support for pipeline parallelism (pp) when distributing checkpointed activations

parent 7f2cc3a4
+1 −1
@@ -240,10 +240,10 @@ def parse_args(extra_args_provider=None, defaults={},
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
        assert args.pipeline_model_parallel_size == 1

    _print_args(args)
    return args
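
For reference, a standalone sketch (not part of this commit; identifiers chosen for illustration) of the configuration constraint these assertions encode: distributing checkpointed activations only helps when tensor model parallelism actually splits them, and it requires an explicit checkpointing method.

from argparse import Namespace

def validate_distribute_checkpointed_activations(args):
    # Mirrors the assertions in parse_args() above; raises AssertionError
    # for configurations the flag cannot support.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, \
            'distributing checkpointed activations requires tensor model parallelism'
        assert args.activations_checkpoint_method is not None, \
            "need a valid checkpoint-activation method ('uniform' or 'block')"

# Example configuration that passes the checks.
validate_distribute_checkpointed_activations(Namespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform'))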
+2 −2
@@ -616,7 +616,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -627,7 +627,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_checkpointed_activations and ( (l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
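
The expression added in both the 'uniform' and 'block' checkpointing paths distributes a checkpointed layer's saved input across tensor-parallel ranks, except for the first local layer on pipeline stages after the first, whose input is the tensor received from the previous stage and is kept whole. A standalone sketch of that predicate (function name and arguments are illustrative, not from the repository):

def should_distribute(distribute_checkpointed_activations, layer_index, pipeline_rank):
    # Distribute a checkpointed layer's saved input unless it is the first
    # local layer on a pipeline stage other than the first, i.e. an input
    # received from the previous stage rather than produced locally.
    return distribute_checkpointed_activations and (
        layer_index > 0 or pipeline_rank == 0)

# First pipeline stage: every checkpointed layer's input is distributed.
assert should_distribute(True, layer_index=0, pipeline_rank=0)
# Later stages: the first local layer keeps its received input whole...
assert not should_distribute(True, layer_index=0, pipeline_rank=1)
# ...but subsequent layers are distributed as before.
assert should_distribute(True, layer_index=1, pipeline_rank=1)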