Commit 12518332 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'dist_act_chkpt' into 'main'

added splitting checkpointed activations across model parallel partitions

See merge request ADLR/megatron-lm!121
parents 930ec4a2 5d29769c
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -112,6 +112,11 @@ def parse_args(extra_args_provider=None, defaults={},
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    _print_args(args)
    return args
@@ -200,6 +205,10 @@ def _add_training_args(parser):
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
+12 −0
Original line number Diff line number Diff line
@@ -73,6 +73,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()
        
        # Autoresume.
        _init_autoresume()
        
@@ -151,3 +154,12 @@ def _write_args_to_tensorboard():
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)))


def _initialize_mem_buffs():
    """Set up the statically allocated memory buffers.

    The only buffer handled here is the one for checkpointed activations;
    it is created through `mpu` when the parsed arguments have
    `distribute_checkpointed_activations` set, and nothing happens
    otherwise.
    """
    # Nothing to allocate unless distributed activation checkpointing
    # was requested.
    if not get_args().distribute_checkpointed_activations:
        return

    mpu.init_checkpointed_activations_memory_buffer()

megatron/memory.py

0 → 100644
+145 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Create a `MemoryBuffer` and register it under `name`.

    Arguments:
        name: unique key of the buffer in the module-level registry.
        numel: number of elements the buffer holds.
        dtype: torch dtype of the buffer elements.
        track_usage: forwarded to `MemoryBuffer` to enable usage tracking.

    Returns the newly created buffer. Asserts that `name` has not been
    registered before.
    """
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    # Build first, register second: a failed construction leaves the
    # registry untouched.
    buff = MemoryBuffer(name, numel, dtype, track_usage)
    _MEM_BUFFS[name] = buff
    return buff


def get_mem_buff(name):
    """Look up a previously allocated memory buffer by its name.

    Raises a KeyError if no buffer has been registered under `name`.
    """
    buff = _MEM_BUFFS[name]
    return buff


class MemoryBuffer:
    """Contiguous memory buffer.

    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
           index of the memory. A memory chunk starting from the `_start`
           index can be `allocated` for an input tensor, with the elements
           of the tensor being copied. The buffer can be reused by
           resetting the `_start` index.

    Arguments:
        name: label used in the messages printed by this class.
        numel: total number of elements in the buffer.
        dtype: torch dtype of the buffer elements.
        track_usage: if True, every `get_data` call accumulates usage
            statistics that `print_average_usage` reports.
    """
    def __init__(self, name, numel, dtype, track_usage):
        # Only rank 0 reports the allocation.
        if torch.distributed.get_rank() == 0:
            # An empty tensor of the requested dtype gives the size in
            # bytes of one element.
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel*element_size/1024/1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage. They only exist when tracking
        # is enabled.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0


    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0


    def is_in_use(self):
        """Whether the buffer currently holds on to any memory."""
        return self._start > 0


    def numel_in_use(self):
        """Return number of elements in use."""
        return self._start


    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values.

        Returns a tensor with the shape of `tensor` that is a view into
        the buffer. Asserts that the dtype matches the buffer dtype and
        that enough free elements are left."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor


    def get_data(self):
        """Return the data currently in use.

        When usage tracking is enabled, each call also records one usage
        sample (elements in use versus total elements)."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]


    def _usage_percent(self):
        """Return the accumulated usage in percent.

        Returns 0.0 when no sample has been recorded yet, i.e. when
        `get_data` has not been called, instead of dividing by zero."""
        if self.total_value == 0.0:
            return 0.0
        return self.in_use_value * 100.0 / self.total_value


    def print_average_usage(self):
        """Print memory usage average over time. We would like this value
        to be as high as possible.

        Requires usage tracking to be enabled; only rank 0 prints."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self._usage_percent()),
                  flush=True)



class RingMemBuffer:
    """A fixed set of memory buffers that is cycled through in order."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        """Allocate `num_buffers` buffers named `name` plus their index.

        Each buffer is created through `allocate_mem_buff` with the same
        `numel`, `dtype` and `track_usage`."""
        self.num_buffers = num_buffers
        ring = []
        for slot in range(num_buffers):
            slot_name = name + ' {}'.format(slot)
            ring.append(
                allocate_mem_buff(slot_name, numel, dtype, track_usage))
        self.buffers = ring
        # Starts before the first slot so that the first request returns
        # the buffer at index 0.
        self._index = -1


    def get_next_buffer(self):
        """Advance to the next buffer of the ring and return it.

        The index wraps around after the last buffer. Asserts that the
        selected buffer is not in use."""
        self._index = self._index + 1
        self._index %= self.num_buffers
        candidate = self.buffers[self._index]
        assert not candidate.is_in_use(), 'buffer is already in use.'
        return candidate
+2 −0
Original line number Diff line number Diff line
@@ -411,6 +411,8 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
+2 −0
Original line number Diff line number Diff line
@@ -45,7 +45,9 @@ from .mappings import scatter_to_model_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer

from .utils import divide
from .utils import split_tensor_along_last_dim
Loading