Commit da8dccd6 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Merge branch 'main' into vision_transformer

parents a7169297 c0400306
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -370,8 +370,7 @@ python tools/create_doc_index.py \

We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.

<!--
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported (not pipeline model parallelism).

<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --load $CHECKPOINT_PATH
        --save $CHECKPOINT_PATH/merged

</pre>
-->

Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.

## GPT Text Generation
+20 −17
Original line number Diff line number Diff line
@@ -92,6 +92,20 @@ def parse_args(extra_args_provider=None, defaults={},
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} \
                       with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                                               flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
@@ -114,20 +128,6 @@ def parse_args(extra_args_provider=None, defaults={},
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} \
                       with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                                               flush=True)
        else:
            setattr(args, key, defaults[key])

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
@@ -263,6 +263,9 @@ def _add_network_size_args(parser):
                       'reasons.')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')

    return parser

@@ -442,9 +445,9 @@ def _add_checkpointing_args(parser):
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true',
    group.add_argument('--no-load-optim', action='store_true', default=None,
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true',
    group.add_argument('--no-load-rng', action='store_true', default=None,
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
@@ -513,7 +516,7 @@ def _add_distributed_args(parser):
                       ' and returns function to complete it instead.'
                       'Also turns on --use-cpu-initialization flag.'
                       'This is for external DDP manager.' )
    group.add_argument('--use-cpu-initialization', action='store_true',
    group.add_argument('--use-cpu-initialization', action='store_true', default=None,
                       help='If set, affine parallel weights initialization uses CPU' )
    return parser

+23 −19
Original line number Diff line number Diff line
@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    assert _CHECKPOINT_VERSION is None, \
        "checkpoint version already set"
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value

def get_checkpoint_version():
@@ -112,11 +113,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    if isinstance(model, torchDDP):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)
    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    if mpu.get_data_parallel_rank() == 0:
    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
@@ -147,16 +147,20 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    print_rank_0('  successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


@@ -198,9 +202,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    if torch.distributed.get_rank() == 0:
        print(' loading checkpoint from {} at iteration {}'.format(
            args.load, iteration), flush=True)
    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

    # Load the checkpoint.
    try:
@@ -285,10 +287,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('  successfully loaded checkpoint from {} at iteration {}'.format(
            args.load, iteration), flush=True)

    print_rank_0(f'  successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')

    return iteration

+23 −9
Original line number Diff line number Diff line
@@ -36,13 +36,14 @@ class BertDataset(Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
                 max_seq_length, short_seq_prob, seed, binary_head):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # Dataset.
        self.indexed_dataset = indexed_dataset
@@ -55,7 +56,8 @@ class BertDataset(Dataset):
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)
                                                    self.name,
                                                    self.binary_head)

        # Vocab stuff.
        tokenizer = get_tokenizer()
@@ -81,7 +83,8 @@ class BertDataset(Dataset):
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng)
                                     self.masked_lm_prob, np_rng,
                                     self.binary_head)


def get_samples_mapping_(indexed_dataset,
@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
                         max_seq_length,
                         short_seq_prob,
                         seed,
                         name):
                         name,
                         binary_head):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose)
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building sapmles index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
@@ -173,7 +178,7 @@ def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng):
                          masked_lm_prob, np_rng, binary_head):
    """Biuld training sample.

    Arguments:
@@ -193,12 +198,21 @@ def build_training_sample(sample,
              the opper bound whereas the numpy one is exclusive.
    """

    if binary_head:
        # We assume that we have at least two sentences in the sample
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Divide sample into two segments (A and B).
    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
    if binary_head:
        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
                                                                  np_rng)
    else:
        tokens_a = []
        for j in range(len(sample)):
            tokens_a.extend(sample[j])
        tokens_b = []
        is_next_random = False

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
+11 −7
Original line number Diff line number Diff line
@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    #print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    assert len_b > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
@@ -150,6 +149,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # [SEP].
        tokens.append(sep_id)
        tokentypes.append(1)
@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                dataset_type=dataset_type)
    # Blending dataset.
    # Parse the values.
@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, dataset_type=dataset_type)
            seed, skip_warmup, binary_head, dataset_type=dataset_type)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head,
                                     dataset_type='standard_bert'):
    
    if dataset_type not in DSET_TYPES:
@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed
                seed=seed,
                binary_head=binary_head
            )

            if dataset_type == DSET_TYPE_ICT:
Loading