megatron/arguments.py  +3 −0

```diff
@@ -72,6 +72,9 @@ def parse_args(extra_args_provider=None, defaults={},
     print('using {} for parameters ...'.format(args.params_dtype), flush=True)
 
+    # Consumed samples.
+    args.consumed_train_samples = 0
+    args.consumed_valid_samples = 0
 
     # Set input defaults.
     for key in defaults:
```

megatron/checkpointing.py  +13 −2

```diff
@@ -89,7 +89,8 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                    consumed_train_samples=None, consumed_valid_samples=None):
     """Save a model checkpoint."""
     args = get_args()
@@ -103,6 +104,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
+        if consumed_train_samples:
+            state_dict['consumed_train_samples'] = consumed_train_samples
+        if consumed_valid_samples:
+            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()
 
         # Optimizer stuff.
@@ -214,6 +219,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                 checkpoint_name))
             sys.exit()
 
+    if 'consumed_train_samples' in state_dict:
+        assert args.consumed_train_samples == 0
+        args.consumed_train_samples = state_dict['consumed_train_samples']
+    if 'consumed_valid_samples' in state_dict:
+        assert args.consumed_valid_samples == 0
+        args.consumed_valid_samples = state_dict['consumed_valid_samples']
+
     # Check arguments.
     if 'args' in state_dict:
```
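The `if consumed_train_samples:` guard keeps the new keys out of the checkpoint when no count is passed, and `load_checkpoint` treats a missing key as zero consumed samples, so old and new checkpoints stay interchangeable. Here is a minimal sketch of that round-trip; the `Args` holder, the `save`/`load` helpers, and the file path are hypothetical stand-ins for illustration, not Megatron's API:

```python
import torch


class Args:
    """Hypothetical stand-in for Megatron's parsed args."""
    consumed_train_samples = 0
    consumed_valid_samples = 0


def save(path, iteration, consumed_train_samples=None,
         consumed_valid_samples=None):
    state_dict = {'iteration': iteration}
    # Mirror the guard above: only write the keys when a count is passed.
    if consumed_train_samples:
        state_dict['consumed_train_samples'] = consumed_train_samples
    if consumed_valid_samples:
        state_dict['consumed_valid_samples'] = consumed_valid_samples
    torch.save(state_dict, path)


def load(path, args):
    state_dict = torch.load(path)
    # Missing keys (an old checkpoint) leave the counters at 0; the asserts
    # catch restoring into args that were already populated.
    if 'consumed_train_samples' in state_dict:
        assert args.consumed_train_samples == 0
        args.consumed_train_samples = state_dict['consumed_train_samples']
    if 'consumed_valid_samples' in state_dict:
        assert args.consumed_valid_samples == 0
        args.consumed_valid_samples = state_dict['consumed_valid_samples']
    return state_dict['iteration']


args = Args()
save('/tmp/ckpt.pt', iteration=5000, consumed_train_samples=160000)
assert load('/tmp/ckpt.pt', args) == 5000
assert args.consumed_train_samples == 160000
```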
megatron/data/sampler.py → megatron/data/data_loaders.py  +32 −1

```diff
@@ -13,7 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Megatorn Sampler."""
+"""Dataloaders."""
+
+
+import torch
+
+from megatron import get_args
+from megatron import mpu
+
+
+def build_pretraining_data_loader(dataset, consumed_samples):
+    """Build dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * world_size
+
+    # Megatron sampler.
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=consumed_samples,
+        global_batch_size=global_batch_size,
+        rank=mpu.get_data_parallel_rank(),
+        world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=args.num_workers,
+                                       pin_memory=True)
 
 
 class MegatronPretrainingSampler:
```
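`build_pretraining_data_loader` hands `consumed_samples` to `MegatronPretrainingSampler`, whose body is not part of this hunk. Below is a sketch of what such a resumable batch sampler can look like, under two assumptions the diff does not confirm: each data-parallel rank takes a contiguous `batch_size`-sized slice of every global batch, and the final partial global batch is dropped.

```python
class MegatronPretrainingSampler:
    """Resumable batch sampler (illustrative sketch, not the diff's class)."""

    def __init__(self, total_samples, consumed_samples, global_batch_size,
                 rank, world_size):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.global_batch_size = global_batch_size
        # Assumption: each rank reads a contiguous slice of the global batch.
        micro_batch_size = global_batch_size // world_size
        self.start = rank * micro_batch_size
        self.end = self.start + micro_batch_size

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        # Start at consumed_samples so a resumed run does not revisit data.
        idx = self.consumed_samples
        while idx + self.global_batch_size <= self.total_samples:
            yield list(range(idx + self.start, idx + self.end))
            idx += self.global_batch_size
```

With `total_samples=64`, `global_batch_size=8`, `world_size=2`, rank 0 yields `[0..3], [8..11], ...` while rank 1 yields `[4..7], [12..15], ...`; restarting with `consumed_samples=8` skips the first global batch on every rank.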
megatron/training.py  +29 −13

```diff
@@ -37,7 +37,7 @@
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import make_data_loader
+from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory
@@ -104,7 +104,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)
 
     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                        consumed_train_samples=args.consumed_train_samples,
+                        consumed_valid_samples=args.consumed_valid_samples)
 
     if args.do_test:
         # Run on test data.
@@ -224,7 +226,8 @@ def setup_model_and_optimizer(model_provider_func):
     while hasattr(unwrapped_model, 'module'):
         unwrapped_model = unwrapped_model.module
 
-    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
+    if args.iteration == 0 and hasattr(unwrapped_model,
+                                       'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()
@@ -414,6 +417,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                          optimizer, lr_scheduler)
         iteration += 1
+        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+                                       args.batch_size
 
         # Logging.
         loss_scale = None
@@ -433,7 +438,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                            consumed_train_samples=args.consumed_train_samples,
+                            consumed_valid_samples=args.consumed_valid_samples)
 
         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -472,6 +479,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                                                        args.eval_iters))
             # Forward evaluation.
             _, loss_dict = forward_step_func(data_iterator, model)
+            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
+                                           * args.batch_size
             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
@@ -517,11 +526,19 @@ def build_train_valid_test_data_iterators(
     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
 
     print_rank_0('> building train, validation, and test datasets ...')
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        # Rank, size, and global batch size.
+
+    # Data parallel size and global batch size.
+    data_parallel_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * data_parallel_size
+
+    # Backward compatibility, assume fixed batch size.
+    if args.iteration > 0 and args.consumed_train_samples == 0:
+        args.consumed_train_samples = args.iteration * global_batch_size
+    if args.iteration > 0 and args.consumed_valid_samples == 0:
+        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+            args.eval_iters * global_batch_size
+
+    # Data loader only on rank 0 of each model parallel group.
+    if mpu.get_model_parallel_rank() == 0:
         # Number of train/valid/test samples.
         train_iters = args.train_iters
@@ -540,12 +557,11 @@ def build_train_valid_test_data_iterators(
             train_val_test_num_samples)
 
         # Build dataloders.
-        comsumed_samples = args.iteration * global_batch_size
-        train_dataloader = make_data_loader(train_ds, comsumed_samples)
-        comsumed_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * global_batch_size
-        valid_dataloader = make_data_loader(valid_ds, comsumed_samples)
-        test_dataloader = make_data_loader(test_ds, comsumed_samples)
+        train_dataloader = build_pretraining_data_loader(
+            train_ds, args.consumed_train_samples)
+        valid_dataloader = build_pretraining_data_loader(
+            valid_ds, args.consumed_valid_samples)
+        test_dataloader = build_pretraining_data_loader(test_ds, 0)
 
         # Flags to know if we need to do training/validation/testing.
         do_train = train_dataloader is not None and args.train_iters > 0
```

megatron/utils.py  +0 −25

```diff
@@ -24,7 +24,6 @@
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.checkpointing import save_checkpoint
-from megatron.data.sampler import MegatronPretrainingSampler
 from megatron.fp16 import FP16_Optimizer
@@ -89,30 +88,6 @@ def check_adlr_autoresume_termination(iteration, model,
             sys.exit(0)
 
 
-def make_data_loader(dataset, consumed_samples):
-    """Buld dataloader given an input dataset."""
-
-    if dataset is None:
-        return None
-    args = get_args()
-
-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset),
-        consumed_samples=consumed_samples,
-        global_batch_size=global_batch_size,
-        rank=rank,
-        world_size=world_size)
-
-    # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
 
 
 def get_ltor_masks_and_position_ids(data, eod_token,
                                     reset_position_ids,
```
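The backward-compatibility branch in `build_train_valid_test_data_iterators` above reconstructs the counters for checkpoints written before they existed, assuming the batch size never changed. A worked example with made-up numbers:

```python
# Hypothetical resume from an old checkpoint that stored only `iteration`.
batch_size = 8                # per data-parallel rank
data_parallel_size = 4
global_batch_size = batch_size * data_parallel_size   # 32
iteration = 5000
eval_interval = 1000
eval_iters = 100

# Training consumed one global batch per iteration ...
consumed_train_samples = iteration * global_batch_size
# ... and evaluation ran eval_iters batches once every eval_interval steps.
consumed_valid_samples = (iteration // eval_interval) * \
    eval_iters * global_batch_size

assert consumed_train_samples == 160000
assert consumed_valid_samples == 16000
```

This agrees with the per-step update in `train()`: adding `data_parallel_world_size * batch_size` on every iteration sums to exactly `iteration * global_batch_size`.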
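Tying the pieces together, here is a single-process dry run of the `MegatronPretrainingSampler` sketched earlier through torch's `DataLoader`; the toy dataset and sizes stand in for one data-parallel rank and are not a Megatron entry point:

```python
import torch

# Toy indexed dataset; real call sites pass Megatron's pretraining datasets.
dataset = torch.utils.data.TensorDataset(torch.arange(64))

# Resume as if 16 samples were already consumed
# (i.e. what args.consumed_train_samples would hold after a restart).
sampler = MegatronPretrainingSampler(total_samples=len(dataset),
                                     consumed_samples=16,
                                     global_batch_size=8,
                                     rank=0, world_size=2)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler,
                                     num_workers=0)

(first_batch,) = next(iter(loader))
# Rank 0 reads the first half of the global batch that starts at sample 16.
assert first_batch.tolist() == [16, 17, 18, 19]
```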