Commit 03d09af0 authored by Jared Casper

Adding checkpoint_util and associated loader and saver.

parent 89e8d27e
+7 −21
@@ -404,29 +404,15 @@ python tools/create_doc_index.py \

We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also fine-tune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing; otherwise the training will start again from the beginning.

Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported on input and pipeline model parallelism on the output. This example reads in a model with 2-way tensor model parallelism and writes out a model with 2-way pipeline model parallelism.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.

<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
TARGET_PIPELINE_MODEL_PARALLEL_SIZE=2

VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m

WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
        --model-type BERT \
        --tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
        --pipeline-model-parallel-size 1 \
        --target-pipeline-model-parallel-size $TARGET_PIPELINE_MODEL_PARALLEL_SIZE \
        --tokenizer-type BertWordPieceLowerCase \
        --vocab-file $VOCAB_FILE \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --load $CHECKPOINT_PATH \
        --save $CHECKPOINT_PATH/merged
python tools/checkpoint_util.py \
        --model-type GPT \
        --load-dir checkpoints/gpt3_tp4_pp4 \
        --save-dir checkpoints/gpt3_tp2_pp2 \
        --target-tensor-parallel-size 2 \
        --target-pipeline-parallel-size 2

</pre>

+11 −1
@@ -21,7 +21,7 @@ import os
import torch

def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
               ignore_unknown_args=False, validate=True):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)
@@ -52,6 +52,11 @@ def parse_args(extra_args_provider=None, defaults={},
    else:
        args = parser.parse_args()

    if validate:
        return validate_args(args, defaults)
    return args

def validate_args(args, defaults={}):
    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
@@ -547,6 +552,11 @@ def _add_checkpointing_args(parser):
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')
    group.add_argument('--no-initialization', action='store_false',
                       help='Do not perform initialization when building model, '
                       'can reduce startup time when definitely loading from a '
                       'checkpoint',
                       dest='perform_initialization')

    return parser
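The `validate` switch makes a two-phase flow possible: parse first, fill in or edit the namespace, then validate. A minimal sketch of such a driver (hypothetical code; the actual tools wire this up through `load_args_from_checkpoint`, shown further below):

<pre>
# Hypothetical driver using the new two-phase parse/validate split.
from megatron.arguments import parse_args, validate_args

# Phase 1: parse only; missing model-architecture args are tolerated for now.
args = parse_args(ignore_unknown_args=True, validate=False)

# ... a tool may fill in args here, e.g. from a saved checkpoint ...

# Phase 2: run the usual validation once the namespace is complete.
args = validate_args(args)
</pre>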

+150 −39
@@ -80,27 +80,56 @@ def ensure_directory_exists(filename):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel_size=None, tensor_rank=None, pipeline_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
    if pipeline_parallel_size is None:
        pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
    if pipeline_parallel_size == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                mpu.get_tensor_model_parallel_rank()),
                            f'mp_rank_{tensor_rank:02d}',
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank()),
                        f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}',
                        'model_optim_rng.pt')

def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
    """Finds the checkpoint for rank 0 without knowing if we are using
    pipeline parallelism or not.

    Since the checkpoint naming scheme changes when pipeline
    parallelism is present, we need to check both naming schemes if
    we don't know whether the checkpoint used pipeline parallelism.

    """

    # Look for checkpoint with no pipelining
    filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                   pipeline_parallel_size=1,
                                   tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filename):
        return filename

    # Look for checkpoint with pipelining (any pipeline_parallel_size > 1
    # selects the pipelined naming scheme)
    filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                   pipeline_parallel_size=2,
                                   tensor_rank=0, pipeline_rank=0)
    if os.path.isfile(filename):
        return filename

    return None

def get_checkpoint_tracker_filename(checkpoints_path):

    """Tracker file rescords the latest chckpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
@@ -125,6 +154,7 @@ def read_metadata(tracker_filename):
        tracker_filename)

    # Get the max iteration retrieved across the ranks.
    if torch.distributed.is_initialized():
        iters_cuda = torch.cuda.LongTensor([iteration])
        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
        max_iter = iters_cuda[0].item()
@@ -137,6 +167,11 @@ def read_metadata(tracker_filename):
                  'metadata while max iteration across the ranks '
                  'is {}, replacing it with max iteration.'.format(
                      rank, iteration, max_iter), flush=True)
    else:
        # When loading a checkpoint outside of training (for example,
        # when editing it), we might not have torch distributed
        # initialized, in this case, just assume we have the latest
        max_iter = iteration
    return max_iter, release


@@ -270,35 +305,38 @@ def fix_query_key_value_ordering(model, checkpoint_version):
        print_rank_0(" succesfully fixed query-key-values ordering for"
                    " checkpoint version {}".format(checkpoint_version))

def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
def _load_base_checkpoint(load_dir, rank0=False):
    """ Load the base state_dict from the given directory

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    # If no tracker file, return nothing
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
        return 0
        return None, False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Load the checkpoint.
    try:
@@ -306,6 +344,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
@@ -319,6 +358,79 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        print_rank_0(e)
        sys.exit()

    return state_dict, release
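`_load_base_checkpoint` thus has a single contract for every caller: `(state_dict, release)` when a checkpoint is found and `(None, False)` when no tracker file exists. A hypothetical caller can branch on the first element:

<pre>
# Hypothetical caller of _load_base_checkpoint handling both outcomes.
state_dict, release = _load_base_checkpoint('checkpoints/gpt2', rank0=True)
if state_dict is None:
    print('no checkpoint found, starting from scratch')
elif release:
    print('loaded a release checkpoint')
else:
    print('resuming from iteration', state_dict['iteration'])
</pre>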

def load_args_from_checkpoint(args, load_arg='load'):
    """Set any arguments that are not currently set from the checkpoint
    specified in the arguments.

    Returns the same args Namespace with the new values added/updated.

    If no checkpoint is specified in args, or if the checkpoint is
    there but invalid, the arguments will not be modified.

    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
        return args

    state_dict, release = _load_base_checkpoint(load_dir, True)

    if not state_dict:
        return args

    if 'args' not in state_dict:
        return args

    checkpoint_args = state_dict['args']
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    def _set_arg(arg_name, old_arg_name=None, force=False):
        if not force and getattr(args, arg_name, None) is not None:
            return

        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name, None)

        if checkpoint_value is not None:
            print(f"Setting {arg_name} to {checkpoint_value}")
            setattr(args, arg_name, checkpoint_value)

    _set_arg('num_layers')
    _set_arg('hidden_size')
    _set_arg('ffn_hidden_size')
    _set_arg('seq_length')
    _set_arg('num_attention_heads')
    _set_arg('kv_channels')
    _set_arg('max_position_embeddings')
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')
    if checkpoint_version < 3.0:
        _set_arg('tensor_model_parallel_size',
                 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args
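The `_set_arg` helper gives explicit command-line values precedence over checkpoint values, except where `force=True`. A small hypothetical illustration:

<pre>
# Hypothetical illustration of the precedence rules in _set_arg.
args.hidden_size = None     # not set on the command line
args.seq_length = 2048      # set explicitly on the command line
args = load_args_from_checkpoint(args)
# args.hidden_size  -> value stored in the checkpoint (was None)
# args.seq_length   -> still 2048 (explicit values win)
# tensor/pipeline parallel sizes -> always taken from a version >= 3.0
# checkpoint, since they are set with force=True
</pre>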

def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    state_dict, release = _load_base_checkpoint(load_dir, False)

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

@@ -445,4 +557,3 @@ def load_biencoder_checkpoint(model, only_query_model=False,
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
+13 −5
@@ -77,11 +77,15 @@ def get_timers():


def set_global_variables(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False):
                         ignore_unknown_args=False, parse_args=True):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
    if parse_args:
        args = _parse_args(extra_args_provider=extra_args_provider,
                           defaults=args_defaults,
                           ignore_unknown_args=ignore_unknown_args)
    else:
        _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
        args = get_args()
    _build_num_microbatches_calculator(args)
    if args.vocab_file:
        _ = _build_tokenizer(args)
@@ -89,6 +93,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
    _set_adlr_autoresume(args)
    _set_timers()

def set_args(args):
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args

def _parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -97,7 +104,8 @@ def _parse_args(extra_args_provider=None, defaults={},
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
                              ignore_unknown_args=ignore_unknown_args,
                              validate=True)
    return _GLOBAL_ARGS
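Together with `validate=False` in `parse_args` and `load_args_from_checkpoint`, the new `set_args` and `parse_args=False` hooks let a checkpoint tool install an args namespace it built itself. A sketch of how a loader could combine them (hypothetical driver; the actual checkpoint loader may differ in details):

<pre>
# Hypothetical loader flow built from the hooks added in this commit.
from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_args, set_global_variables

args = parse_args(ignore_unknown_args=True, validate=False)
args = load_args_from_checkpoint(args)  # architecture args from checkpoint
validate_args(args)
set_args(args)                          # install as the global namespace
set_global_variables(parse_args=False)  # tokenizer etc. from existing args
</pre>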


+4 −2
@@ -141,6 +141,7 @@ class Embedding(MegatronModule):
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        if args.perform_initialization:
            self.init_method(self.position_embeddings.weight)

        # Token type embedding.
@@ -152,6 +153,7 @@ class Embedding(MegatronModule):
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None
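These guards rely on `--no-initialization` being a `store_false` flag with `dest='perform_initialization'`, so the attribute defaults to True and existing callers are unaffected. A self-contained sketch of that argparse idiom:

<pre>
# Stand-alone illustration of the store_false idiom used for
# --no-initialization: perform_initialization defaults to True,
# and passing the flag flips it to False.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--no-initialization', action='store_false',
               dest='perform_initialization')
assert p.parse_args([]).perform_initialization is True
assert p.parse_args(['--no-initialization']).perform_initialization is False
</pre>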