Commit 5d29769c authored by mohammad

addressed Jared's comments

parent d6c4248b
+5 −0
@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
     # Mixed precision checks.
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
+    # Activation checkpointing.
+    if args.distribute_checkpointed_activations:
+        assert args.checkpoint_activations, \
+            'for distribute-checkpointed-activations to work you '\
+            'need to enable checkpoint-activations'
 
     _print_args(args)
     return args
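The new check makes --distribute-checkpointed-activations depend on --checkpoint-activations. A minimal, hypothetical sketch of the failure it guards against (SimpleNamespace stands in for the parsed argparse namespace and is not part of this commit):

    from types import SimpleNamespace

    # Hypothetical stand-in for the namespace returned by parse_args().
    args = SimpleNamespace(checkpoint_activations=False,
                           distribute_checkpointed_activations=True)

    if args.distribute_checkpointed_activations:
        # Mirrors the assertion added above: distributing checkpointed
        # activations only makes sense when activations are checkpointed.
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you ' \
            'need to enable checkpoint-activations'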
+1 −10
@@ -162,13 +162,4 @@ def _initialize_mem_buffs():

     # Initialize memory for checkpointed activations.
     if args.distribute_checkpointed_activations:
-        per_layer = args.batch_size * args.max_position_embeddings * \
-                    args.hidden_size // args.model_parallel_size
-        assert args.num_layers % args.checkpoint_num_layers == 0, \
-            'number of layers is not divisible by checkpoint-num-layers'
-        num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-        numel = per_layer * num_checkpointer_layers
-        dtype = torch.half
-        if not args.fp16:
-            dtype = torch.float
-        mpu.init_checkpointed_activations_memory_buffer(numel, dtype)
+        mpu.init_checkpointed_activations_memory_buffer()
+16 −1
@@ -24,6 +24,7 @@ from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 from torch.utils.checkpoint import detach_variable
 
+from megatron import get_args
 from megatron.memory import allocate_mem_buff
 
 from .initialize import get_data_parallel_rank
@@ -40,9 +41,23 @@ _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
 
 
-def init_checkpointed_activations_memory_buffer(numel, dtype):
+def init_checkpointed_activations_memory_buffer():
     """Initialize the memory buffer for the checkpointed activations."""
+    args = get_args()
+
+    per_layer = args.batch_size * args.max_position_embeddings * \
+                args.hidden_size // args.model_parallel_size
+    assert args.num_layers % args.checkpoint_num_layers == 0, \
+        'number of layers is not divisible by checkpoint-num-layers'
+    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    numel = per_layer * num_checkpointer_layers
+    dtype = torch.half
+    if not args.fp16:
+        dtype = torch.float
+
     global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
     assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
         'checkpointed activations memory buffer is already allocated.'
     _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
         'checkpointed activations', numel, dtype, track_usage=False)
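After this change the buffer size is derived entirely from the global args inside the mpu module. A back-of-the-envelope sketch of that sizing arithmetic with hypothetical hyperparameters (the values below are illustrative, not from this commit):

    # Hypothetical configuration for illustration only.
    batch_size = 8
    max_position_embeddings = 1024
    hidden_size = 1024
    model_parallel_size = 2
    num_layers = 24
    checkpoint_num_layers = 1

    # Same arithmetic as the new init_checkpointed_activations_memory_buffer().
    per_layer = batch_size * max_position_embeddings * hidden_size // model_parallel_size
    assert num_layers % checkpoint_num_layers == 0
    numel = per_layer * (num_layers // checkpoint_num_layers)

    print(numel)                # 100663296 elements
    print(numel * 2 / 2 ** 20)  # ~192 MiB when fp16 (2 bytes per element)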