Loading megatron/checkpointing.py +3 −1 Original line number Diff line number Diff line Loading @@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args): _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') _compare('tokenizer_type') if args.data_parallel_random_init: _compare('data_parallel_random_init') if get_checkpoint_version() < 3.0: _compare('tensor_model_parallel_size', old_arg_name='model_parallel_size') Loading @@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args): _compare('tensor_model_parallel_size') _compare('pipeline_model_parallel_size') def ensure_directory_exists(filename): """Build filename's path if it does not already exists.""" dirname = os.path.dirname(filename) Loading Loading @@ -408,6 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True if 'rng_state' in state_dict: # access rng_state for data parallel rank if args.data_parallel_random_init: rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] else: rng_state = state_dict['rng_state'][0] Loading Loading
megatron/checkpointing.py +3 −1 Original line number Diff line number Diff line Loading @@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args): _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') _compare('tokenizer_type') if args.data_parallel_random_init: _compare('data_parallel_random_init') if get_checkpoint_version() < 3.0: _compare('tensor_model_parallel_size', old_arg_name='model_parallel_size') Loading @@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args): _compare('tensor_model_parallel_size') _compare('pipeline_model_parallel_size') def ensure_directory_exists(filename): """Build filename's path if it does not already exists.""" dirname = os.path.dirname(filename) Loading Loading @@ -408,6 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True if 'rng_state' in state_dict: # access rng_state for data parallel rank if args.data_parallel_random_init: rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] else: rng_state = state_dict['rng_state'][0] Loading