Loading megatron/checkpointing.py +4 −2 Original line number Diff line number Diff line Loading @@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): if 'args' in state_dict: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') Loading Loading
megatron/checkpointing.py +4 −2 Original line number Diff line number Diff line Loading @@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): if 'args' in state_dict: checkpoint_args = state_dict['args'] check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) else: print_rank_0('could not find arguments in the checkpoint ...') Loading