Commit 6b5ae488 authored by rprenger

Fixing merge conflicts with main

parents 3d718bfc 3860e995
+7 −1
@@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={},
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'checkpointed activations only across tensor model ' \
            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distribute-checkpointed-activations to work you '\
            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'
        assert args.num_layers_per_virtual_pipeline_stage is None, \
            'currently distributed checkpoint activations only supported for ' \
            'non-interleaved pipeline parallelism'

    _print_args(args)
    return args
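
(A minimal sketch, not part of the commit, of the flag combinations the asserts above accept; the Namespace construction is a hypothetical stand-in for the parsed CLI arguments.)

from argparse import Namespace

def check_distributed_checkpointing(args):
    """Mirror of the validation above: distributing checkpointed activations
    requires tensor parallelism, a valid activation-checkpoint method, and a
    non-interleaved pipeline schedule."""
    if not args.distribute_checkpointed_activations:
        return
    assert args.tensor_model_parallel_size > 1
    assert args.activations_checkpoint_method in ('uniform', 'block')
    assert args.num_layers_per_virtual_pipeline_stage is None

# Accepted combination.
check_distributed_checkpointing(Namespace(
    distribute_checkpointed_activations=True,
    tensor_model_parallel_size=2,
    activations_checkpoint_method='uniform',
    num_layers_per_virtual_pipeline_stage=None))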
+23 −9
@@ -64,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    # Set pytorch JIT layer fusion options.
    _set_jit_fusion_options()

    args = get_args()
    if  args.lazy_mpu_init:
        args.use_cpu_initialization=True
@@ -78,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()
        
        # Autoresume.
        _init_autoresume()

@@ -226,10 +226,24 @@ def write_args_to_tensorboard():
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()
def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()
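
(For context, a minimal sketch of why these flags are centralized in initialization: any torch.jit.script-ed elementwise kernel defined afterwards is compiled under whichever fuser the flags selected. The version check mirrors the one above; toy_fused_op is a hypothetical kernel, not code from this commit.)

import torch

TORCH_MAJOR, TORCH_MINOR = (int(v) for v in torch.__version__.split('.')[:2])
use_nvfuser = (TORCH_MAJOR, TORCH_MINOR) >= (1, 10)  # same cut-off as in the diff

@torch.jit.script
def toy_fused_op(x, bias):
    # Elementwise chain that the selected fuser (nvFuser on >= 1.10,
    # otherwise the legacy fuser) can compile into a single GPU kernel.
    return torch.tanh(x + bias) * x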
+0 −4
@@ -15,10 +15,6 @@

import torch

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
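
(For reference, 0.3989423 is 1/sqrt(2*pi), the constant in the standard normal density; it appears in the derivative of the exact GELU that the tanh-based fused kernel approximates. A small illustrative sketch, not from this commit:)

import torch

def gelu_exact_grad(x):
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
    #                   = 0.5 * (1 + erf(x / sqrt(2))) + (1/sqrt(2*pi)) * x * exp(-x^2 / 2)
    return 0.5 * (1.0 + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

def gelu_tanh(x):
    # tanh approximation used by the fused bias-GELU kernel; 0.79788456 = sqrt(2/pi).
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))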
+19 −7
@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
@@ -544,6 +539,7 @@ class ParallelTransformer(MegatronModule):
        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -607,8 +603,22 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               Parallel ranks if the `distribute-checkpointed-activations
               is on and either of the following conditions is met:
                 - it is not the first layer in the in the pipeline stage.
                   The first layer is used in the pipeline parallelism 
                   and changing its shape throws error in the backward pass.
                 - we are at the first pipline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
@@ -618,6 +628,7 @@ class ParallelTransformer(MegatronModule):
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
@@ -628,6 +639,7 @@ class ParallelTransformer(MegatronModule):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
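
(To make the helper's condition concrete, a standalone sketch of the same predicate with the pipeline rank passed in explicitly rather than read from mpu; the layer and rank values below are illustrative only.)

def should_distribute(layer_number, pipeline_rank, distribute_checkpointed_activations=True):
    not_first_layer_in_stage = layer_number > 0
    is_first_pipeline_stage = pipeline_rank == 0
    return distribute_checkpointed_activations and \
        (not_first_layer_in_stage or is_first_pipeline_stage)

# First pipeline stage: every checkpointed layer has its activations distributed.
assert [should_distribute(l, pipeline_rank=0) for l in range(3)] == [True, True, True]
# Later stages: layer 0 keeps its shape, since its input participates in pipeline communication.
assert [should_distribute(l, pipeline_rank=1) for l in range(3)] == [False, True, True]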
+0 −2
@@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
