Commit ebbe40cd authored by Raul Puri

Merge branch 'move_vocab_padding_to_utils' into 'master'

Move vocab padding to utils

See merge request ADLR/megatron-lm!6
parents 6b68bb8a 7e46ca58
+13 −0
@@ -185,6 +185,19 @@ def report_memory(name):
     print_rank_0(string)
 
 
+def vocab_size_with_padding(num_tokens, args):
+    """Pad vocab to a multiple of make_vocab_size_divisible_by * world size."""
+    after = num_tokens
+    multiple = args.make_vocab_size_divisible_by * \
+               mpu.get_model_parallel_world_size()
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0('> padded vocab (size: {}) with {} dummy '
+                 'tokens (new size: {})'.format(
+                     num_tokens, after - num_tokens, after))
+    return after
+
+
 def initialize_distributed(args):
     """Initialize torch.distributed."""

+6 −12
@@ -44,7 +44,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import initialize_distributed
 from megatron.utils import set_random_seed
 from megatron.utils import wrap_model_for_distributed_training
-
+from megatron.utils import vocab_size_with_padding
 
 def get_model(args):
     """Build the model."""
@@ -477,19 +477,13 @@ def get_train_val_test_data(args):
         ds_type = 'BERT'
         data_config.set_defaults(data_set_type=ds_type, transpose=False)
         (train_data, val_data, test_data), tokenizer = data_config.apply(args)
-        before = tokenizer.num_tokens
-        after = before
-        multiple = args.make_vocab_size_divisible_by * \
-                   mpu.get_model_parallel_world_size()
-        while (after % multiple) != 0:
-            after += 1
-        print_rank_0('> padded vocab (size: {}) with {} dummy '
-                     'tokens (new size: {})'.format(
-                         before, after - before, after))
+        num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
         # Need to broadcast num_tokens and num_type_tokens.
-        token_counts = torch.cuda.LongTensor([after,
+        token_counts = torch.cuda.LongTensor([num_tokens,
                                               tokenizer.num_type_tokens,
-                                              int(args.do_train), int(args.do_valid), int(args.do_test)])
+                                              int(args.do_train),
+                                              int(args.do_valid),
+                                              int(args.do_test)])
     else:
         token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
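
The token_counts tensor above exists only so the counts can be broadcast from the rank that built the data loaders to every other rank. A rough sketch of that pattern with plain torch.distributed follows; the function name, the rank-0 source, and the default process group are illustrative assumptions, not the script's actual call sites:

    import torch.distributed as dist

    def broadcast_token_counts(token_counts, src=0):
        # Every rank calls this; only rank `src` holds meaningful values going in.
        dist.broadcast(token_counts, src)
        num_tokens, num_type_tokens, do_train, do_valid, do_test = token_counts.tolist()
        return num_tokens, num_type_tokens, bool(do_train), bool(do_valid), bool(do_test)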

+7 −10
@@ -43,6 +43,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import initialize_distributed
 from megatron.utils import set_random_seed
 from megatron.utils import wrap_model_for_distributed_training
+from megatron.utils import vocab_size_with_padding
 
 from gpt2_data_loader import make_gpt2_dataloaders
 
@@ -509,17 +510,13 @@ def get_train_val_test_data(args):
             num_tokens = tokenizer.num_tokens
             eod_token = tokenizer.get_command('eos').Id
             assert eod_token == tokenizer.get_command('pad').Id
-        before = num_tokens
-        after = before
-        multiple = args.make_vocab_size_divisible_by * \
-                   mpu.get_model_parallel_world_size()
-        while (after % multiple) != 0:
-            after += 1
-        print_rank_0('> padded vocab (size: {}) with {} dummy '
-                     'tokens (new size: {})'.format(
-                         before, after - before, after))
+        # pad.
+        num_tokens = vocab_size_with_padding(num_tokens, args)
         print_rank_0('> found end-of-document token: {}'.format(eod_token))
-        token_counts = torch.cuda.LongTensor([after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)])
+        token_counts = torch.cuda.LongTensor([num_tokens, eod_token,
+                                              int(args.do_train),
+                                              int(args.do_valid),
+                                              int(args.do_test)])
     else:
         token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
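
The loop in vocab_size_with_padding also has a simple closed form; as a quick sanity check, a self-contained snippet using GPT-2-style numbers (the divisor of 128, the world size of 2, and the 50257-token vocab are illustrative values, not read from the diff):

    def padded_size(num_tokens, divisible_by, world_size):
        # Integer round-up to the next multiple, equivalent to the incrementing loop.
        multiple = divisible_by * world_size
        return ((num_tokens + multiple - 1) // multiple) * multiple

    # 50257 tokens round up to 50432 when the multiple is 128 * 2 = 256,
    # i.e. 175 dummy tokens are appended to the embedding table.
    assert padded_size(50257, 128, 2) == 50432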