megatron/arguments.py (+5 −0)

```diff
@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
+    # Activation checkpointing.
+    if args.distribute_checkpointed_activations:
+        assert args.checkpoint_activations, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to enable checkpoint-activations'
 
     _print_args(args)
     return args
```
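The added check enforces a flag dependency: distributing checkpointed activations across model-parallel ranks only makes sense if activation checkpointing is enabled in the first place. As a minimal, self-contained sketch of that pattern (this is a hypothetical parser, not Megatron's actual `parse_args`):

```python
import argparse


def parse_args_sketch(argv=None):
    """Minimal sketch of the flag-dependency check added in arguments.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-activations', action='store_true',
                        help='Recompute activations in backward to save memory.')
    parser.add_argument('--distribute-checkpointed-activations', action='store_true',
                        help='Shard checkpointed activations across model-parallel ranks.')
    args = parser.parse_args(argv)

    # The distributed variant requires the base checkpointing flag.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you ' \
            'need to enable checkpoint-activations'
    return args


# This call would raise the AssertionError, since the base flag is missing:
# parse_args_sketch(['--distribute-checkpointed-activations'])
```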
megatron/initialize.py (+1 −10)

```diff
@@ -162,13 +162,4 @@ def _initialize_mem_buffs():
 
     # Initialize memory for checkpointed activations.
     if args.distribute_checkpointed_activations:
-        per_layer = args.batch_size * args.max_position_embeddings * \
-                    args.hidden_size // args.model_parallel_size
-        assert args.num_layers % args.checkpoint_num_layers == 0, \
-            'number of layers is not divisible by checkpoint-num-layers'
-        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-        numel = per_layer * num_checkpointer_layers
-        dtype = torch.half
-        if not args.fp16:
-            dtype = torch.float
-        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
+        mpu.init_checkpointed_activations_memory_buffer()
```
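The ten removed lines are the buffer-sizing logic, which reappears inside mpu/random.py below; the call site now only gates on the flag. To illustrate what that sizing computes, here is a worked example with made-up GPT-style hyperparameters (the numbers are assumptions for illustration, not values from this PR):

```python
# Hypothetical hyperparameters, chosen only to illustrate the arithmetic.
batch_size = 8
max_position_embeddings = 1024   # sequence length the buffer is sized for
hidden_size = 1536
model_parallel_size = 2
num_layers = 24
checkpoint_num_layers = 1        # checkpoint every layer

# Elements stored per checkpointed layer, split across model-parallel ranks.
per_layer = batch_size * max_position_embeddings * hidden_size // model_parallel_size

assert num_layers % checkpoint_num_layers == 0
num_checkpointed_layers = num_layers // checkpoint_num_layers

numel = per_layer * num_checkpointed_layers
print(numel)                               # 150994944 elements
print(numel * 2 / 2**20, 'MiB in fp16')    # 288.0 MiB at 2 bytes/element
```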
megatron/mpu/random.py (+16 −1)

```diff
@@ -24,6 +24,7 @@
 from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
 
+from megatron import get_args
 from megatron.memory import allocate_mem_buff
 from .initialize import get_data_parallel_rank
@@ -40,9 +41,23 @@
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
 
-def init_checkpointed_activations_memory_buffer(numel, dtype):
+def init_checkpointed_activations_memory_buffer():
     """Initializ the memory buffer for the checkpointed activations."""
+    args = get_args()
+
+    per_layer = args.batch_size * args.max_position_embeddings * \
+                args.hidden_size // args.model_parallel_size
+    assert args.num_layers % args.checkpoint_num_layers == 0, \
+        'number of layers is not divisible by checkpoint-num-layers'
+    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    numel = per_layer * num_checkpointer_layers
+    dtype = torch.half
+    if not args.fp16:
+        dtype = torch.float
+
     global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
     assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
         'checkpointed activations memory buffer is already allocated.'
     _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
         'checkpointed activations', numel, dtype, track_usage=False)
```
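The function now derives the buffer size and dtype from the global args instead of taking `numel` and `dtype` as parameters, and hands both to `allocate_mem_buff`. The diff does not show megatron.memory itself, so as a hedged sketch of what a pre-allocated contiguous buffer like this could look like (names and behavior here are assumptions, not the real implementation):

```python
import torch


class MemoryBufferSketch:
    """Hedged sketch of a pre-allocated, contiguous activation buffer.

    NOT megatron.memory's implementation (which this diff does not show);
    it only illustrates the idea of carving tensors out of one fixed-size
    allocation instead of allocating per checkpointed layer.
    """

    def __init__(self, name, numel, dtype, track_usage=False):
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.track_usage = track_usage
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # One contiguous allocation reused across forward passes.
        self.data = torch.empty(numel, dtype=dtype, device=device)
        self._offset = 0

    def reset(self):
        """Start reusing the buffer from the beginning (e.g. each iteration)."""
        self._offset = 0

    def allocate(self, shape):
        """Return a view into the buffer with the requested shape."""
        n = 1
        for s in shape:
            n *= s
        assert self._offset + n <= self.numel, 'memory buffer exhausted'
        view = self.data[self._offset:self._offset + n].view(*shape)
        self._offset += n
        return view


def allocate_mem_buff_sketch(name, numel, dtype, track_usage=False):
    # Mirrors the call site in init_checkpointed_activations_memory_buffer().
    return MemoryBufferSketch(name, numel, dtype, track_usage=track_usage)
```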