megatron/arguments.py  +1 −1

@@ -240,10 +240,10 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
+        assert args.tensor_model_parallel_size > 1
         assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
             'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
-        assert args.pipeline_model_parallel_size == 1

     _print_args(args)
     return args
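A hypothetical sketch (not code from this PR) of the validation rule after this change, with an example of the flag combination it is meant to accept; check_distribute_checkpointed_activations and the SimpleNamespace args below are made up for illustration:

    from types import SimpleNamespace

    def check_distribute_checkpointed_activations(args):
        # Mirrors the assertion kept in parse_args(): distributing checkpointed
        # activations only makes sense when a checkpointing method is selected.
        if args.distribute_checkpointed_activations:
            assert args.activations_checkpoint_method in ('uniform', 'block'), \
                "for distribute-checkpointed-activations to work you need to use " \
                "a valid checkpoint-activation method ('uniform' or 'block')"
            # On this branch a tensor-parallel size > 1 also appears to be required,
            # since the saved activations are sharded across the tensor-parallel group.
            assert args.tensor_model_parallel_size > 1

    # Example: tensor parallelism 2, pipeline parallelism 2, uniform checkpointing.
    check_distribute_checkpointed_activations(SimpleNamespace(
        distribute_checkpointed_activations=True,
        activations_checkpoint_method='uniform',
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2))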
megatron/model/transformer.py  +2 −2

@@ -616,7 +616,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':

@@ -627,7 +627,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_checkpointed_activations and ((l > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
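The new second argument to mpu.checkpoint only changes behaviour on pipeline stages after the first: there, the saved input of the very first checkpointed chunk is no longer sharded across the tensor-parallel group, presumably because that tensor is the activation received from the previous pipeline stage. A standalone restatement of the condition (distribute_this_chunk is a made-up helper, not part of the PR):

    def distribute_this_chunk(distribute_checkpointed_activations: bool,
                              first_layer_of_chunk: int,
                              pipeline_rank: int) -> bool:
        # Same boolean the diff passes to mpu.checkpoint():
        # distribute, except for the chunk starting at layer 0 on ranks > 0.
        return distribute_checkpointed_activations and (
            first_layer_of_chunk > 0 or pipeline_rank == 0)

    # First chunk on a non-first pipeline stage keeps its input intact.
    assert distribute_this_chunk(True, 0, 1) is False
    # Any chunk on the first stage, and later chunks on every stage, still distribute.
    assert distribute_this_chunk(True, 0, 0) is True
    assert distribute_this_chunk(True, 2, 1) is True
    # With the flag off nothing is ever distributed.
    assert distribute_this_chunk(False, 0, 0) is False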