Commit 09d38884 authored by Vijay Korthikanti

checkpointing rng_state of all data parallel ranks

parent 343dc97a
@@ -140,6 +140,32 @@ def read_metadata(tracker_filename):
     return max_iter, release


+def get_rng_state():
+    """Collect rng state across data parallel ranks."""
+    rng_state = {
+        'random_rng_state': random.getstate(),
+        'np_rng_state': np.random.get_state(),
+        'torch_rng_state': torch.get_rng_state(),
+        'cuda_rng_state': torch.cuda.get_rng_state(),
+        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+
+    rng_state_list = None
+    if torch.distributed.is_initialized() and \
+            mpu.get_data_parallel_world_size() > 1:
+        if mpu.get_data_parallel_rank() == 0:
+            rng_state_list = \
+                [None for i in range(mpu.get_data_parallel_world_size())]
+        torch.distributed.gather_object(
+            rng_state,
+            rng_state_list,
+            dst=mpu.get_data_parallel_src_rank(),
+            group=mpu.get_data_parallel_group())
+    else:
+        rng_state_list = [rng_state]
+
+    return rng_state_list
+
+
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
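For context (not part of the diff): torch.distributed.gather_object is a collective, which is why save_checkpoint in the next hunk calls get_rng_state() on every rank before its data-parallel-rank-0 guard, and why only the destination rank allocates the output list while all other ranks pass None. A minimal, self-contained sketch of that calling convention, using a single-process group to stand in for the data parallel group:

import os
import torch.distributed as dist

# Single-process "world" standing in for the data parallel group.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

obj = {'rank': dist.get_rank()}  # stands in for the rng_state dict
# Only the destination rank allocates the gather list; others pass None,
# mirroring the rank-0 check in get_rng_state() above.
gathered = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
dist.gather_object(obj, gathered, dst=0)
print(gathered)  # the destination rank now holds one object per rank
dist.destroy_process_group()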
@@ -150,6 +176,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

+    # Collect rng state across data parallel ranks.
+    rng_state = get_rng_state()
+
     if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
@@ -173,12 +202,7 @@

         # RNG states.
         if not args.no_save_rng:
-            state_dict['random_rng_state'] = random.getstate()
-            state_dict['np_rng_state'] = np.random.get_state()
-            state_dict['torch_rng_state'] = torch.get_rng_state()
-            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-            state_dict['rng_tracker_states'] \
-                = mpu.get_cuda_rng_tracker().get_states()
+            state_dict['rng_state'] = rng_state

         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
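The net effect on the checkpoint file, sketched below with illustrative values: the five flat per-process RNG keys are replaced by a single 'rng_state' list holding one state dict per data parallel rank, in rank order.

# Illustrative 'rng_state' layout after this change (values elided).
state_dict = {
    # ... other entries such as 'args', 'iteration', and model weights ...
    'rng_state': [
        {   # one dict per data parallel rank, in rank order
            'random_rng_state': ...,
            'np_rng_state': ...,
            'torch_rng_state': ...,
            'cuda_rng_state': ...,
            'rng_tracker_states': ...,
        },
        # ... a matching dict for each remaining data parallel rank
    ],
}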
@@ -381,6 +405,19 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
+            if 'rng_state' in state_dict:
+                # Access rng_state for this data parallel rank.
+                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                random.setstate(rng_state['random_rng_state'])
+                np.random.set_state(rng_state['np_rng_state'])
+                torch.set_rng_state(rng_state['torch_rng_state'])
+                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+                # Check for empty states array.
+                if not rng_state['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    rng_state['rng_tracker_states'])
+            else:  # backward compatibility
+                random.setstate(state_dict['random_rng_state'])
+                np.random.set_state(state_dict['np_rng_state'])
+                torch.set_rng_state(state_dict['torch_rng_state'])
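Finally, a small self-contained round trip (again not from the commit) showing the capture/restore symmetry that the loader relies on, for the host-side generators:

import random
import numpy as np
import torch

# Capture the three host RNG states, as get_rng_state() does.
snapshot = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}
first = (random.random(), np.random.rand(), torch.rand(1))

# Restore them, as load_checkpoint does above; the stream replays.
random.setstate(snapshot['random_rng_state'])
np.random.set_state(snapshot['np_rng_state'])
torch.set_rng_state(snapshot['torch_rng_state'])
second = (random.random(), np.random.rand(), torch.rand(1))

assert first[0] == second[0] and first[1] == second[1]
assert torch.equal(first[2], second[2])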