megatron/arguments.py  +10 −10

@@ -312,9 +312,16 @@ def _add_data_args(parser):
                        choices=['BertWordPieceLowerCase',
                                 'GPT2BPETokenizer'],
                        help='What type of tokenizer to use.')
-    parser.add_argument('--data-impl', type=str, default='infer',
-                        choices=['lazy', 'cached', 'mmap', 'infer'],
-                        help='Implementation of indexed datasets.')
+    group.add_argument('--data-impl', type=str, default='infer',
+                       choices=['lazy', 'cached', 'mmap', 'infer'],
+                       help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset position ids after end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset self attention mask after '
+                       'end-of-document token.')
+    group.add_argument('--eod-mask-loss', action='store_true',
+                       help='Mask loss for the end of document tokens.')

     return parser

@@ -340,13 +347,6 @@ def _add_gpt2_args(parser):
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                        help='The filename containing all the shards '
                        'sizes for numpy data loader')
-    group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset posistion ids after end-of-document token.')
-    group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention maske after '
-                       'end-of-document token.')
-    group.add_argument('--eod-mask-loss', action='store_true',
-                       help='Mask loss for the end of document tokens.')

     return parser
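The switch from parser.add_argument to group.add_argument for --data-impl only affects help output: arguments added through an argument group are listed under that group's title in --help, while parser.add_argument drops them into the generic options section; parsing behaviour is identical. A minimal, self-contained sketch of the pattern (not Megatron code; the group title used here is an assumption, since the hunk does not show how the group is created in _add_data_args):

    import argparse

    parser = argparse.ArgumentParser()
    # Assumed group title; Megatron's _add_data_args creates its own group.
    group = parser.add_argument_group(title='data and dataloader')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')

    # Arguments parse exactly as before; only the --help grouping changes.
    args = parser.parse_args(['--data-impl', 'mmap', '--reset-position-ids'])
    print(args.data_impl, args.reset_position_ids)  # mmap True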
megatron/utils.py  +26 −2

@@ -21,8 +21,10 @@
 import torch

 from megatron import get_args
 from megatron import get_adlr_autoresume
+from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
+from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer

@@ -87,7 +89,30 @@ def check_adlr_autoresume_termination(iteration, model,
         sys.exit(0)


 ###################################################
+def make_data_loader(dataset):
+    """Build dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    # Data parallel arguments.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    # Use a simple sampler with distributed batch sampler.
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler=sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 def get_ltor_masks_and_position_ids(data,

@@ -145,4 +170,3 @@ def get_ltor_masks_and_position_ids(data,
             prev_index = i + 1

     return attention_mask, loss_mask, position_ids
-
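For readers unfamiliar with the sampler stack: make_data_loader builds one loader per data-parallel rank, but the batch sampler is sized to the global batch (args.batch_size * world_size) and each rank keeps only its own slice, so every rank still steps through args.batch_size samples per iteration. A toy, self-contained sketch of that sharding scheme (an illustration, not the megatron.data_utils.samplers.DistributedBatchSampler implementation):

    def shard_global_batch(indices, global_batch_size, rank, world_size):
        """Yield this rank's slice of each global batch; drop the last partial batch."""
        per_rank = global_batch_size // world_size
        batch = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == global_batch_size:
                start = rank * per_rank
                yield batch[start:start + per_rank]
                batch = []

    # Example: 8 samples, per-rank batch size 2, two data-parallel ranks.
    for rank in range(2):
        print(rank, list(shard_global_batch(range(8), global_batch_size=4,
                                            rank=rank, world_size=2)))
    # rank 0 -> [[0, 1], [4, 5]]
    # rank 1 -> [[2, 3], [6, 7]]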
pretrain_bert.py  +4 −23

@@ -23,14 +23,12 @@
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import BertModel
 from megatron.training import pretrain
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
     args = get_args()

@@ -151,26 +149,9 @@ def get_train_val_test_data():
             skip_warmup=(not args.mmap_warmup))
         print_rank_0("> finished creating BERT datasets ...")

-        def make_data_loader_(dataset):
-            if not dataset:
-                return None
-            # Use a simple sampler with distributed batch sampler.
-            sampler = torch.utils.data.SequentialSampler(dataset)
-            batch_sampler = DistributedBatchSampler(
-                sampler=sampler,
-                batch_size=global_batch_size,
-                drop_last=True,
-                rank=data_parallel_rank,
-                world_size=data_parallel_size)
-            # Torch dataloader.
-            return torch.utils.data.DataLoader(dataset,
-                                               batch_sampler=batch_sampler,
-                                               num_workers=args.num_workers,
-                                               pin_memory=True)
-
-        train_data = make_data_loader_(train_ds)
-        valid_data = make_data_loader_(valid_ds)
-        test_data = make_data_loader_(test_ds)
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)

         do_train = train_data is not None and args.train_iters > 0
         do_valid = valid_data is not None and args.eval_iters > 0
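A minimal sketch of how the shared helper is consumed after this change: the returned object is an ordinary torch.utils.data.DataLoader (or None when the split is missing), so callers guard on None and then iterate it. The iteration shown here is illustrative, not the exact training-loop code:

    train_data = make_data_loader(train_ds)   # None if train_ds is None
    do_train = train_data is not None and args.train_iters > 0
    if do_train:
        data_iterator = iter(train_data)
        batch = next(data_iterator)           # one per-rank batch of args.batch_size samples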
pretrain_gpt2.py  +14 −27

@@ -25,10 +25,10 @@
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.gpt2_dataset import GPT2Dataset
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses

@@ -121,32 +121,19 @@ def make_gpt2_dataloaders():
-    seq_length = args.seq_length
-    initial_seed = args.seed
-
-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    def make_data_loader_(data_path):
-        # Build the dataset.
-        dataset = GPT2Dataset(data_path, input_data_sizes_file,
-                              seq_length, initial_seed)
-        # Use a simple sampler with distributed batch sampler.
-        sampler = torch.utils.data.SequentialSampler(dataset)
-        batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                batch_size=global_batch_size,
-                                                drop_last=True,
-                                                rank=rank,
-                                                world_size=world_size)
-        # Torch dataloader.
-        return torch.utils.data.DataLoader(dataset,
-                                           batch_sampler=batch_sampler,
-                                           num_workers=num_workers,
-                                           pin_memory=True)
-
-    train = make_data_loader_(os.path.join(args.data_path, 'train'))
-    valid = make_data_loader_(os.path.join(args.data_path, 'valid'))
-    test = make_data_loader_(os.path.join(args.data_path, 'test'))
+    # Build the datasets.
+    def build_dataset_(name):
+        return GPT2Dataset(os.path.join(args.data_path, name),
+                           args.input_data_sizes_file,
+                           args.seq_length, args.seed)
+
+    train_ds = build_dataset_('train')
+    valid_ds = build_dataset_('valid')
+    test_ds = build_dataset_('test')
+
+    # Dataloaders
+    train = make_data_loader(train_ds)
+    valid = make_data_loader(valid_ds)
+    test = make_data_loader(test_ds)

     args.do_train = False
     args.do_valid = False
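Since pretrain_gpt2.py consumes the relocated flags through get_ltor_masks_and_position_ids, a simplified single-sequence illustration of what --reset-position-ids and --eod-mask-loss control may help. This is a toy re-implementation, not the actual megatron.utils function, and it omits the per-document attention masking that --reset-attention-mask additionally enables:

    import torch

    def toy_position_and_loss_masks(tokens, eod_token, reset_position_ids, eod_mask_loss):
        """Toy sketch: position ids and loss mask for one sequence of token ids."""
        seq_length = tokens.numel()
        position_ids = torch.arange(seq_length, dtype=torch.long)
        loss_mask = torch.ones(seq_length, dtype=torch.float)
        if eod_mask_loss:
            # Do not compute loss on end-of-document tokens.
            loss_mask[tokens == eod_token] = 0.0
        if reset_position_ids:
            # Positions restart at 0 for the document that follows each EOD token.
            for i in (tokens == eod_token).nonzero(as_tuple=True)[0].tolist():
                position_ids[i + 1:] = torch.arange(seq_length - (i + 1), dtype=torch.long)
        return position_ids, loss_mask

    tokens = torch.tensor([5, 6, 7, 0, 8, 9, 0, 10])   # 0 plays the EOD token here
    pos, loss = toy_position_and_loss_masks(tokens, eod_token=0,
                                            reset_position_ids=True, eod_mask_loss=True)
    print(pos.tolist())   # [0, 1, 2, 3, 0, 1, 2, 0]
    print(loss.tolist())  # [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]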