megatron/model/transformer.py  +2 −13

@@ -609,18 +609,6 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
 
-<<<<<<< HEAD
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                self.distribute_checkpointed_activations,
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
-=======
-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
-
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
@@ -629,6 +617,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
+                    self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -639,13 +628,13 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
+                        self.distribute_checkpointed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
         else:
             raise ValueError("Invalid activation checkpoint method.")
->>>>>>> main
 
         return hidden_states
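For readers outside the Megatron codebase: the two branches retained from main implement the usual activation-checkpointing strategies. 'uniform' recomputes every chunk of activations_checkpoint_num_layers layers (only chunk inputs are kept), while 'block' checkpoints only the first activations_checkpoint_num_layers layers individually and runs the rest without recomputation. The sketch below mirrors that control flow with plain torch.utils.checkpoint; SimpleTransformer, its Linear layers, and the constructor arguments are illustrative stand-ins, and Megatron's mpu.checkpoint differs in that it also takes the distribute_checkpointed_activations flag this diff threads through.

# Minimal sketch of the 'uniform' vs 'block' checkpoint dispatch (plain PyTorch,
# not Megatron; names here are illustrative, not from the diff above).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SimpleTransformer(nn.Module):
    def __init__(self, num_layers=8, hidden=64,
                 checkpoint_method='uniform', checkpoint_num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.checkpoint_method = checkpoint_method
        self.checkpoint_num_layers = checkpoint_num_layers

    def custom(self, start, end):
        # Forward over layers [start, end), mirroring custom() in the diff.
        def custom_forward(x):
            for layer in self.layers[start:end]:
                x = torch.relu(layer(x))
            return x
        return custom_forward

    def forward(self, hidden_states):
        if self.checkpoint_method == 'uniform':
            # Checkpoint the input of every chunk of checkpoint_num_layers
            # layers; everything inside a chunk is recomputed on backward.
            l = 0
            while l < self.num_layers:
                hidden_states = checkpoint(
                    self.custom(l, l + self.checkpoint_num_layers),
                    hidden_states, use_reentrant=False)
                l += self.checkpoint_num_layers
        elif self.checkpoint_method == 'block':
            # Checkpoint only the first checkpoint_num_layers layers, one at a
            # time; run the remaining layers without recomputation.
            for l in range(self.num_layers):
                if l < self.checkpoint_num_layers:
                    hidden_states = checkpoint(
                        self.custom(l, l + 1), hidden_states,
                        use_reentrant=False)
                else:
                    hidden_states = self.custom(l, l + 1)(hidden_states)
        else:
            raise ValueError("Invalid activation checkpoint method.")
        return hidden_states


# Quick check that both methods run forward and backward.
x = torch.randn(4, 64, requires_grad=True)
for method in ('uniform', 'block'):
    out = SimpleTransformer(checkpoint_method=method)(x)
    out.sum().backward()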