Commit 1788c910 authored by Mohammad's avatar Mohammad

Both BERT and GPT-2 tested and working

parent 5f8623db
+10 −10
@@ -312,9 +312,16 @@ def _add_data_args(parser):
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
-    parser.add_argument('--data-impl', type=str, default='infer',
+    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset position ids after end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset self attention mask after '
+                       'end-of-document token.')
+    group.add_argument('--eod-mask-loss', action='store_true',
+                       help='Mask loss for the end of document tokens.')

    return parser

@@ -340,13 +347,6 @@ def _add_gpt2_args(parser):
    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                       help='The filename containing all the shards '
                       'sizes for numpy data loader')
-    group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset position ids after end-of-document token.')
-    group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention mask after '
-                       'end-of-document token.')
-    group.add_argument('--eod-mask-loss', action='store_true',
-                       help='Mask loss for the end of document tokens.')

    return parser

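The one-line `parser.add_argument` to `group.add_argument` fix in the first hunk only changes `--help` output: both calls register the option on the same underlying parser, but only the group call files it under the argument group's heading. A minimal sketch of the distinction (the group title here is illustrative, not taken from this commit):

```python
import argparse

parser = argparse.ArgumentParser(description='Megatron-style arguments')
group = parser.add_argument_group(title='data and dataloader')

# Registered on the group: listed under the group's heading in --help.
group.add_argument('--data-impl', type=str, default='infer',
                   choices=['lazy', 'cached', 'mmap', 'infer'],
                   help='Implementation of indexed datasets.')

# Registered on the bare parser: parses identically, but --help files it
# under the generic options section instead of the data group.
parser.add_argument('--reset-position-ids', action='store_true',
                    help='Reset position ids after end-of-document token.')

args = parser.parse_args(['--data-impl', 'mmap', '--reset-position-ids'])
assert args.data_impl == 'mmap' and args.reset_position_ids is True
```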
+26 −2
@@ -21,8 +21,10 @@ import torch

from megatron import get_args
from megatron import get_adlr_autoresume
+from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
+from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer


@@ -87,7 +89,30 @@ def check_adlr_autoresume_termination(iteration, model,
        sys.exit(0)


-###################################################
+def make_data_loader(dataset):
+    """Build dataloader given an input dataset."""
+    if dataset is None:
+        return None
+    args = get_args()
+
+    # Data parallel arguments.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    # Use a simple sampler with distributed batch sampler.
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler=sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)


def get_ltor_masks_and_position_ids(data,
@@ -145,4 +170,3 @@ def get_ltor_masks_and_position_ids(data,
                    prev_index = i + 1

    return attention_mask, loss_mask, position_ids
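The new `make_data_loader` helper pairs a plain `SequentialSampler` with `DistributedBatchSampler` so each data-parallel rank consumes a disjoint slice of every global batch. The sampler itself lives in `megatron.data_utils.samplers` and is not shown in this diff; as a rough mental model only (a minimal sketch, not Megatron's actual implementation, with an illustrative class name), such a sampler buffers indices up to the global batch size and yields only this rank's contiguous shard:

```python
import torch

class SimpleDistributedBatchSampler(torch.utils.data.Sampler):
    """Minimal sketch: shard a global batch across data-parallel ranks.

    Hypothetical stand-in for megatron.data_utils.samplers.
    DistributedBatchSampler; the real class is not part of this commit.
    """

    def __init__(self, sampler, batch_size, drop_last, rank, world_size):
        assert batch_size % world_size == 0
        self.sampler = sampler            # yields dataset indices
        self.batch_size = batch_size      # *global* batch size
        self.drop_last = drop_last
        self.rank = rank
        self.world_size = world_size
        self.local_batch_size = batch_size // world_size

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                # Each rank keeps only its contiguous slice of the batch.
                start = self.rank * self.local_batch_size
                yield batch[start:start + self.local_batch_size]
                batch = []
        if batch and not self.drop_last:
            start = self.rank * self.local_batch_size
            yield batch[start:start + self.local_batch_size]

    def __len__(self):
        n = len(self.sampler) // self.batch_size
        if not self.drop_last and len(self.sampler) % self.batch_size:
            n += 1
        return n
```

`DataLoader` then treats each yielded index list as one per-rank micro-batch, so every rank sees `args.batch_size` samples per step while a global step covers `batch_size * world_size` samples.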
+4 −23
@@ -23,14 +23,12 @@ from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
-from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import BertModel
from megatron.training import pretrain
+from megatron.utils import make_data_loader
from megatron.utils import reduce_losses

-
-

def model_provider():
    """Build the model."""
    args = get_args()
@@ -151,26 +149,9 @@ def get_train_val_test_data():
            skip_warmup=(not args.mmap_warmup))
        print_rank_0("> finished creating BERT datasets ...")

-        def make_data_loader_(dataset):
-            if not dataset:
-                return None
-            # Use a simple sampler with distributed batch sampler.
-            sampler = torch.utils.data.SequentialSampler(dataset)
-            batch_sampler = DistributedBatchSampler(
-                sampler=sampler,
-                batch_size=global_batch_size,
-                drop_last=True,
-                rank=data_parallel_rank,
-                world_size=data_parallel_size)
-            # Torch dataloader.
-            return torch.utils.data.DataLoader(dataset,
-                                               batch_sampler=batch_sampler,
-                                               num_workers=args.num_workers,
-                                               pin_memory=True)
-
-        train_data = make_data_loader_(train_ds)
-        valid_data = make_data_loader_(valid_ds)
-        test_data = make_data_loader_(test_ds)
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)

        do_train = train_data is not None and args.train_iters > 0
        do_valid = valid_data is not None and args.eval_iters > 0
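The `do_train`/`do_valid` flags computed at the end of this hunk feed the usual Megatron pattern of agreeing on the training schedule across ranks; the broadcast itself falls outside the lines shown. A rough sketch of that step, assuming `torch.distributed` is already initialized and the flags are produced on rank 0 (the function name is hypothetical):

```python
import torch

def broadcast_data_flags(do_train, do_valid, do_test):
    # Pack the rank-0 flags into one tensor so a single broadcast
    # distributes them; other ranks pass placeholder values.
    flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    torch.distributed.broadcast(flags, src=0)
    return tuple(bool(f.item()) for f in flags)
```

Every rank then derives the same schedule even when only rank 0 touched the dataset files.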
+14 −27
@@ -25,10 +25,10 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import GPT2Dataset
-from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import make_data_loader
from megatron.utils import reduce_losses


@@ -121,32 +121,19 @@ def make_gpt2_dataloaders():
    seq_length = args.seq_length
    initial_seed = args.seed

-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    def make_data_loader_(data_path):
-        # Build the dataset.
-        dataset = GPT2Dataset(data_path, input_data_sizes_file,
-                              seq_length, initial_seed)
-        # Use a simple sampler with distributed batch sampler.
-        sampler = torch.utils.data.SequentialSampler(dataset)
-        batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                batch_size=global_batch_size,
-                                                drop_last=True,
-                                                rank=rank,
-                                                world_size=world_size)
-        # Torch dataloader.
-        return torch.utils.data.DataLoader(dataset,
-                                           batch_sampler=batch_sampler,
-                                           num_workers=num_workers,
-                                           pin_memory=True)
-
-    train = make_data_loader_(os.path.join(args.data_path, 'train'))
-    valid = make_data_loader_(os.path.join(args.data_path, 'valid'))
-    test = make_data_loader_(os.path.join(args.data_path, 'test'))
+    # Build the datasets.
+    def build_dataset_(name):
+        return GPT2Dataset(os.path.join(args.data_path, name),
+                           args.input_data_sizes_file,
+                           args.seq_length, args.seed)
+    train_ds = build_dataset_('train')
+    valid_ds = build_dataset_('valid')
+    test_ds = build_dataset_('test')
+
+    # Dataloaders
+    train = make_data_loader(train_ds)
+    valid = make_data_loader(valid_ds)
+    test = make_data_loader(test_ds)

    args.do_train = False
    args.do_valid = False