Commit 30b92cf5 authored by mshoeybi
Browse files

resolved conflicts

parent 8cb389b8
Loading
Loading
Loading
Loading
+2 −13
Original line number Diff line number Diff line
@@ -609,18 +609,6 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

<<<<<<< HEAD
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                self.distribute_checkpointed_activations,
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers
=======
        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
@@ -629,6 +617,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -639,13 +628,13 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")
>>>>>>> main

        return hidden_states