Commit 3366a5b0 authored by Mohammad's avatar Mohammad
Browse files

refactored pretrain-bert

parent 27e14f82
Loading
Loading
Loading
Loading
+14 −27
Original line number Diff line number Diff line
@@ -20,15 +20,15 @@ import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers

from megatron import mpu
from megatron.model import BertModel
from megatron import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import pretrain
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses




def model_provider():
@@ -176,36 +176,23 @@ def get_train_val_test_data():
        do_valid = valid_data is not None and args.eval_iters > 0
        do_test = test_data is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
        token_counts = torch.cuda.LongTensor([num_tokens,
                                              2, # hard coded num_type_tokens
                                              int(do_train),
                                              int(do_valid),
                                              int(do_test)])
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.vocab_size = token_counts[0].item()
    args.tokentype_size = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    return train_data, valid_data, test_data


if __name__ == "__main__":

    '''
    from megatron.initialize import initialize_megatron
    initialize_megatron(args_defaults={
        'tokenizer_type': 'BertWordPieceLowerCase'})
    exit()
    '''
    pretrain(get_train_val_test_data,
             model_provider, forward_step,
    pretrain(get_train_val_test_data, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})