Loading megatron/arguments.py +7 −1 Original line number Diff line number Diff line Loading @@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={}, 'residual connection in fp32 only supported when using fp16 or bf16.' # Activation checkpointing. if args.distribute_checkpointed_activations: assert args.tensor_model_parallel_size > 1, 'can distribute ' \ 'checkpointed activations only across tensor model ' \ 'parallel groups' assert args.activations_checkpoint_method is not None, \ 'for distribute-checkpointed-activations to work you '\ 'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')' 'need to use a activation-checkpoint method ' assert args.num_layers_per_virtual_pipeline_stage is None, \ 'currently distrobuted checkpoint activations only supported for ' \ 'nointerleaved pipeline parallelism' _print_args(args) return args Loading megatron/initialize.py +0 −11 Original line number Diff line number Diff line Loading @@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, # Megatron's MPU is the master. Complete initialization right away. finish_mpu_init() # Initialize memory buffers. _initialize_mem_buffs() # Autoresume. _init_autoresume() Loading Loading @@ -224,11 +221,3 @@ def write_args_to_tensorboard(): writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration) def _initialize_mem_buffs(): """Initialize manually allocated static memory.""" args = get_args() # Initialize memory for checkpointed activations. if args.distribute_checkpointed_activations: mpu.init_checkpointed_activations_memory_buffer() megatron/model/transformer.py +19 −2 Original line number Diff line number Diff line Loading @@ -544,6 +544,7 @@ class ParallelTransformer(MegatronModule): # Store activation checkpoiting flag. self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers self.distribute_checkpointed_activations = args.distribute_checkpointed_activations # Number of layers. assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \ Loading Loading @@ -607,8 +608,22 @@ class ParallelTransformer(MegatronModule): return x_ return custom_forward # Make sure memory is freed. mpu.reset_checkpointed_activations_memory_buffer() def distribute_checkpointed_activations_helper(layer_number): """Distribute checkpointed activations across the tensor model Parallel ranks if the `distribute-checkpointed-activations is on and either of the following conditions is met: - it is not the first layer in the in the pipeline stage. The first layer is used in the pipeline parallelism and changing its shape throws error in the backward pass. - we are at the first pipline stage so the input tensor is not used in pipeline parallelism. Note that no pipeline parallelism is a special case of this. """ not_first_layer_in_pipeline_stage = (layer_number > 0) is_first_pipeline_stage = ( mpu.get_pipeline_model_parallel_rank() == 0) return self.distribute_checkpointed_activations and \ (not_first_layer_in_pipeline_stage or is_first_pipeline_stage) if self.activations_checkpoint_method == 'uniform': # Uniformly divide the total number of Transformer layers and checkpoint Loading @@ -618,6 +633,7 @@ class ParallelTransformer(MegatronModule): while l < self.num_layers: hidden_states = mpu.checkpoint( custom(l, l + self.activations_checkpoint_num_layers), distribute_checkpointed_activations_helper(l), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) l += self.activations_checkpoint_num_layers elif self.activations_checkpoint_method == 'block': Loading @@ -628,6 +644,7 @@ class ParallelTransformer(MegatronModule): if l < self.activations_checkpoint_num_layers: hidden_states = mpu.checkpoint( custom(l, l + 1), distribute_checkpointed_activations_helper(l), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) else: hidden_states = custom(l, l + 1)( Loading megatron/mpu/__init__.py +0 −2 Original line number Diff line number Diff line Loading @@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region from .random import checkpoint from .random import get_cuda_rng_tracker from .random import init_checkpointed_activations_memory_buffer from .random import model_parallel_cuda_manual_seed from .random import reset_checkpointed_activations_memory_buffer from .random import gather_split_1d_tensor from .random import split_tensor_into_1d_equal_chunks Loading megatron/mpu/random.py +23 −55 Original line number Diff line number Diff line Loading @@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' # Whether apply model parallelsim to checkpointed hidden states. _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None def init_checkpointed_activations_memory_buffer(): """Initializ the memory buffer for the checkpointed activations.""" args = get_args() per_layer = args.micro_batch_size * args.max_position_embeddings * \ args.hidden_size // args.tensor_model_parallel_size num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size() if args.virtual_pipeline_model_parallel_size is not None: num_layers = num_layers // args.virtual_pipeline_model_parallel_size if args.activations_checkpoint_method == 'uniform': assert num_layers % args.activations_checkpoint_num_layers == 0, \ 'total number of layers is not divisible by checkpoint-chunk_size' num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers elif args.activations_checkpoint_method == 'block': assert args.activations_checkpoint_num_layers <= num_layers, \ 'total number of layers is fewer than the number of layers to checkpoint' num_checkpointer_layers = args.activations_checkpoint_num_layers numel = per_layer * num_checkpointer_layers dtype = torch.half if not args.fp16: dtype = torch.float global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \ 'checkpointed activations memory buffer is already allocated.' _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff( 'checkpointed activations', numel, dtype, track_usage=False) def reset_checkpointed_activations_memory_buffer(): """Reset the memory used for checkpointing.""" if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset() def _set_cuda_rng_state(new_state, device=-1): """Sets the random number generator state of the current GPU. Loading Loading @@ -110,13 +70,20 @@ def _set_cuda_rng_state(new_state, device=-1): _lazy_call(cb) def split_tensor_into_1d_equal_chunks(tensor): def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): """Break a tensor into equal 1D chunks.""" data = tensor.view(-1) partition_size = torch.numel(data) // get_tensor_model_parallel_world_size() partition_size = torch.numel(tensor) // \ get_tensor_model_parallel_world_size() start_index = partition_size * get_tensor_model_parallel_rank() end_index = start_index + partition_size return data[start_index:end_index] if new_buffer: data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) data.copy_(tensor.view(-1)[start_index:end_index]) else: data = tensor.view(-1)[start_index:end_index] return data def gather_split_1d_tensor(tensor): Loading Loading @@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function): tracked/set/reset. """ @staticmethod def forward(ctx, run_function, *args): def forward(ctx, run_function, distribute_checkpointed_activations, *args): ctx.run_function = run_function ctx.distribute_checkpointed_activations \ = distribute_checkpointed_activations # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() Loading @@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function): # Divide hidden states across model parallel group and only keep # the chunk corresponding to the current rank. if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if distribute_checkpointed_activations: ctx.input_0_shape = args[0].data.shape args[0].data = split_tensor_into_1d_equal_chunks(args[0].data) args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add( args[0].data) args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) # Store everything. ctx.save_for_backward(*args) return outputs @staticmethod Loading @@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function): raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible") inputs = ctx.saved_tensors if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if ctx.distribute_checkpointed_activations: inputs[0].data = gather_split_1d_tensor(inputs[0].data) inputs[0].data = inputs[0].data.view(ctx.input_0_shape) Loading Loading @@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function): torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None,) + grads return (None, None) + grads def checkpoint(function, *args): def checkpoint(function, distribute_checkpointed_activations, *args): """Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.""" return CheckpointFunction.apply(function, *args) return CheckpointFunction.apply(function, distribute_checkpointed_activations, *args) Loading
megatron/arguments.py +7 −1 Original line number Diff line number Diff line Loading @@ -240,9 +240,15 @@ def parse_args(extra_args_provider=None, defaults={}, 'residual connection in fp32 only supported when using fp16 or bf16.' # Activation checkpointing. if args.distribute_checkpointed_activations: assert args.tensor_model_parallel_size > 1, 'can distribute ' \ 'checkpointed activations only across tensor model ' \ 'parallel groups' assert args.activations_checkpoint_method is not None, \ 'for distribute-checkpointed-activations to work you '\ 'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')' 'need to use a activation-checkpoint method ' assert args.num_layers_per_virtual_pipeline_stage is None, \ 'currently distrobuted checkpoint activations only supported for ' \ 'nointerleaved pipeline parallelism' _print_args(args) return args Loading
megatron/initialize.py +0 −11 Original line number Diff line number Diff line Loading @@ -77,9 +77,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, # Megatron's MPU is the master. Complete initialization right away. finish_mpu_init() # Initialize memory buffers. _initialize_mem_buffs() # Autoresume. _init_autoresume() Loading Loading @@ -224,11 +221,3 @@ def write_args_to_tensorboard(): writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration) def _initialize_mem_buffs(): """Initialize manually allocated static memory.""" args = get_args() # Initialize memory for checkpointed activations. if args.distribute_checkpointed_activations: mpu.init_checkpointed_activations_memory_buffer()
megatron/model/transformer.py +19 −2 Original line number Diff line number Diff line Loading @@ -544,6 +544,7 @@ class ParallelTransformer(MegatronModule): # Store activation checkpoiting flag. self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers self.distribute_checkpointed_activations = args.distribute_checkpointed_activations # Number of layers. assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \ Loading Loading @@ -607,8 +608,22 @@ class ParallelTransformer(MegatronModule): return x_ return custom_forward # Make sure memory is freed. mpu.reset_checkpointed_activations_memory_buffer() def distribute_checkpointed_activations_helper(layer_number): """Distribute checkpointed activations across the tensor model Parallel ranks if the `distribute-checkpointed-activations is on and either of the following conditions is met: - it is not the first layer in the in the pipeline stage. The first layer is used in the pipeline parallelism and changing its shape throws error in the backward pass. - we are at the first pipline stage so the input tensor is not used in pipeline parallelism. Note that no pipeline parallelism is a special case of this. """ not_first_layer_in_pipeline_stage = (layer_number > 0) is_first_pipeline_stage = ( mpu.get_pipeline_model_parallel_rank() == 0) return self.distribute_checkpointed_activations and \ (not_first_layer_in_pipeline_stage or is_first_pipeline_stage) if self.activations_checkpoint_method == 'uniform': # Uniformly divide the total number of Transformer layers and checkpoint Loading @@ -618,6 +633,7 @@ class ParallelTransformer(MegatronModule): while l < self.num_layers: hidden_states = mpu.checkpoint( custom(l, l + self.activations_checkpoint_num_layers), distribute_checkpointed_activations_helper(l), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) l += self.activations_checkpoint_num_layers elif self.activations_checkpoint_method == 'block': Loading @@ -628,6 +644,7 @@ class ParallelTransformer(MegatronModule): if l < self.activations_checkpoint_num_layers: hidden_states = mpu.checkpoint( custom(l, l + 1), distribute_checkpointed_activations_helper(l), hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) else: hidden_states = custom(l, l + 1)( Loading
megatron/mpu/__init__.py +0 −2 Original line number Diff line number Diff line Loading @@ -56,9 +56,7 @@ from .mappings import scatter_to_tensor_model_parallel_region from .random import checkpoint from .random import get_cuda_rng_tracker from .random import init_checkpointed_activations_memory_buffer from .random import model_parallel_cuda_manual_seed from .random import reset_checkpointed_activations_memory_buffer from .random import gather_split_1d_tensor from .random import split_tensor_into_1d_equal_chunks Loading
megatron/mpu/random.py +23 −55 Original line number Diff line number Diff line Loading @@ -37,46 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' # Whether apply model parallelsim to checkpointed hidden states. _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None def init_checkpointed_activations_memory_buffer(): """Initializ the memory buffer for the checkpointed activations.""" args = get_args() per_layer = args.micro_batch_size * args.max_position_embeddings * \ args.hidden_size // args.tensor_model_parallel_size num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size() if args.virtual_pipeline_model_parallel_size is not None: num_layers = num_layers // args.virtual_pipeline_model_parallel_size if args.activations_checkpoint_method == 'uniform': assert num_layers % args.activations_checkpoint_num_layers == 0, \ 'total number of layers is not divisible by checkpoint-chunk_size' num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers elif args.activations_checkpoint_method == 'block': assert args.activations_checkpoint_num_layers <= num_layers, \ 'total number of layers is fewer than the number of layers to checkpoint' num_checkpointer_layers = args.activations_checkpoint_num_layers numel = per_layer * num_checkpointer_layers dtype = torch.half if not args.fp16: dtype = torch.float global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \ 'checkpointed activations memory buffer is already allocated.' _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff( 'checkpointed activations', numel, dtype, track_usage=False) def reset_checkpointed_activations_memory_buffer(): """Reset the memory used for checkpointing.""" if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset() def _set_cuda_rng_state(new_state, device=-1): """Sets the random number generator state of the current GPU. Loading Loading @@ -110,13 +70,20 @@ def _set_cuda_rng_state(new_state, device=-1): _lazy_call(cb) def split_tensor_into_1d_equal_chunks(tensor): def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): """Break a tensor into equal 1D chunks.""" data = tensor.view(-1) partition_size = torch.numel(data) // get_tensor_model_parallel_world_size() partition_size = torch.numel(tensor) // \ get_tensor_model_parallel_world_size() start_index = partition_size * get_tensor_model_parallel_rank() end_index = start_index + partition_size return data[start_index:end_index] if new_buffer: data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) data.copy_(tensor.view(-1)[start_index:end_index]) else: data = tensor.view(-1)[start_index:end_index] return data def gather_split_1d_tensor(tensor): Loading Loading @@ -259,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function): tracked/set/reset. """ @staticmethod def forward(ctx, run_function, *args): def forward(ctx, run_function, distribute_checkpointed_activations, *args): ctx.run_function = run_function ctx.distribute_checkpointed_activations \ = distribute_checkpointed_activations # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() Loading @@ -272,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function): # Divide hidden states across model parallel group and only keep # the chunk corresponding to the current rank. if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if distribute_checkpointed_activations: ctx.input_0_shape = args[0].data.shape args[0].data = split_tensor_into_1d_equal_chunks(args[0].data) args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add( args[0].data) args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) # Store everything. ctx.save_for_backward(*args) return outputs @staticmethod Loading @@ -290,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function): raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible") inputs = ctx.saved_tensors if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if ctx.distribute_checkpointed_activations: inputs[0].data = gather_split_1d_tensor(inputs[0].data) inputs[0].data = inputs[0].data.view(ctx.input_0_shape) Loading Loading @@ -319,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function): torch.autograd.backward(outputs, args) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None,) + grads return (None, None) + grads def checkpoint(function, *args): def checkpoint(function, distribute_checkpointed_activations, *args): """Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.""" return CheckpointFunction.apply(function, *args) return CheckpointFunction.apply(function, distribute_checkpointed_activations, *args)