Commit 798b6a64 authored by Vijay Korthikanti
Browse files

address review comments

parent b037a69e
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args):
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args):
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
@@ -408,6 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:

                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]