Loading megatron/checkpointing.py +5 −13 Original line number Diff line number Diff line Loading @@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path): return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') def save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=None, consumed_valid_samples=None): def save_checkpoint(iteration, model, optimizer, lr_scheduler): """Save a model checkpoint.""" args = get_args() Loading @@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, state_dict['args'] = args state_dict['checkpoint_version'] = 2.0 state_dict['iteration'] = iteration if consumed_train_samples: state_dict['consumed_train_samples'] = consumed_train_samples if consumed_valid_samples: state_dict['consumed_valid_samples'] = consumed_valid_samples state_dict['model'] = model.state_dict_for_save_checkpoint() # Optimizer stuff. Loading Loading @@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): checkpoint_name)) sys.exit() if 'consumed_train_samples' in state_dict: # Check arguments. assert args.consumed_train_samples == 0 args.consumed_train_samples = state_dict['consumed_train_samples'] if 'consumed_valid_samples' in state_dict: assert args.consumed_valid_samples == 0 args.consumed_valid_samples = state_dict['consumed_valid_samples'] # Check arguments. if 'args' in state_dict: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') Loading megatron/training.py +2 −6 Original line number Diff line number Diff line Loading @@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider, iteration, False) if args.save and iteration != 0: save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=args.consumed_train_samples, consumed_valid_samples=args.consumed_valid_samples) save_checkpoint(iteration, model, optimizer, lr_scheduler) if args.do_test: # Run on test data. Loading Loading @@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, # Checkpointing if args.save and args.save_interval and \ iteration % args.save_interval == 0: save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=args.consumed_train_samples, consumed_valid_samples=args.consumed_valid_samples) save_checkpoint(iteration, model, optimizer, lr_scheduler) # Evaluation if args.eval_interval and iteration % args.eval_interval == 0 and \ Loading Loading
megatron/checkpointing.py +5 −13 Original line number Diff line number Diff line Loading @@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path): return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') def save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=None, consumed_valid_samples=None): def save_checkpoint(iteration, model, optimizer, lr_scheduler): """Save a model checkpoint.""" args = get_args() Loading @@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, state_dict['args'] = args state_dict['checkpoint_version'] = 2.0 state_dict['iteration'] = iteration if consumed_train_samples: state_dict['consumed_train_samples'] = consumed_train_samples if consumed_valid_samples: state_dict['consumed_valid_samples'] = consumed_valid_samples state_dict['model'] = model.state_dict_for_save_checkpoint() # Optimizer stuff. Loading Loading @@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): checkpoint_name)) sys.exit() if 'consumed_train_samples' in state_dict: # Check arguments. assert args.consumed_train_samples == 0 args.consumed_train_samples = state_dict['consumed_train_samples'] if 'consumed_valid_samples' in state_dict: assert args.consumed_valid_samples == 0 args.consumed_valid_samples = state_dict['consumed_valid_samples'] # Check arguments. if 'args' in state_dict: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') Loading
megatron/training.py +2 −6 Original line number Diff line number Diff line Loading @@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider, iteration, False) if args.save and iteration != 0: save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=args.consumed_train_samples, consumed_valid_samples=args.consumed_valid_samples) save_checkpoint(iteration, model, optimizer, lr_scheduler) if args.do_test: # Run on test data. Loading Loading @@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, # Checkpointing if args.save and args.save_interval and \ iteration % args.save_interval == 0: save_checkpoint(iteration, model, optimizer, lr_scheduler, consumed_train_samples=args.consumed_train_samples, consumed_valid_samples=args.consumed_valid_samples) save_checkpoint(iteration, model, optimizer, lr_scheduler) # Evaluation if args.eval_interval and iteration % args.eval_interval == 0 and \ Loading