megatron/checkpointing.py  (+52 −15)

@@ -140,6 +140,32 @@ def read_metadata(tracker_filename):
     return max_iter, release


+def get_rng_state():
+    """ collect rng state across data parallel ranks """
+    rng_state = {
+        'random_rng_state': random.getstate(),
+        'np_rng_state': np.random.get_state(),
+        'torch_rng_state': torch.get_rng_state(),
+        'cuda_rng_state': torch.cuda.get_rng_state(),
+        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+
+    rng_state_list = None
+    if torch.distributed.is_initialized() and \
+            mpu.get_data_parallel_world_size() > 1:
+        if mpu.get_data_parallel_rank() == 0:
+            rng_state_list = \
+                [None for i in range(mpu.get_data_parallel_world_size())]
+        torch.distributed.gather_object(
+            rng_state,
+            rng_state_list,
+            dst=mpu.get_data_parallel_src_rank(),
+            group=mpu.get_data_parallel_group())
+    else:
+        rng_state_list = [rng_state]
+
+    return rng_state_list
+
+
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -150,6 +176,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

+    # collect rng state across data parallel ranks
+    rng_state = get_rng_state()
+
     if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.

@@ -173,12 +202,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # RNG states.
         if not args.no_save_rng:
-            state_dict['random_rng_state'] = random.getstate()
-            state_dict['np_rng_state'] = np.random.get_state()
-            state_dict['torch_rng_state'] = torch.get_rng_state()
-            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-            state_dict['rng_tracker_states'] \
-                = mpu.get_cuda_rng_tracker().get_states()
+            state_dict["rng_state"] = rng_state

         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)

@@ -381,6 +405,19 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
+            if 'rng_state' in state_dict:
+                # access rng_state for data parallel rank
+                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                random.setstate(rng_state['random_rng_state'])
+                np.random.set_state(rng_state['np_rng_state'])
+                torch.set_rng_state(rng_state['torch_rng_state'])
+                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+                # Check for empty states array
+                if not rng_state['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    rng_state['rng_tracker_states'])
+            else:  # backward compatibility
+                random.setstate(state_dict['random_rng_state'])
+                np.random.set_state(state_dict['np_rng_state'])
+                torch.set_rng_state(state_dict['torch_rng_state'])
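
The gather in get_rng_state() relies on torch.distributed.gather_object, which collects one picklable Python object from every rank in a group onto the destination rank; only the destination rank supplies a pre-sized receive list, and all other ranks pass None. Below is a minimal standalone sketch of that pattern, not part of this diff; the script name and the use of the gloo backend are assumptions for illustration only.

# gather_object_sketch.py -- hypothetical illustration, not Megatron code.
import torch
import torch.distributed as dist


def main():
    # gloo backend so the sketch also runs on CPU-only machines.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank contributes its own picklable state object,
    # analogous to the per-rank rng_state dict in get_rng_state().
    obj = {"rank": rank, "torch_rng_state": torch.get_rng_state()}

    # Only the destination rank allocates the receive list;
    # every other rank must pass None.
    gather_list = [None] * world_size if rank == 0 else None
    dist.gather_object(obj, gather_list, dst=0)

    if rank == 0:
        # gather_list[i] now holds the object contributed by rank i,
        # which is why load_checkpoint can index the saved list by
        # the data parallel rank.
        print([entry["rank"] for entry in gather_list])

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Run with, e.g., torchrun --nproc_per_node=2 gather_object_sketch.py: only rank 0 ends up holding the full list, which mirrors how only data parallel rank 0 writes the gathered rng_state list into the checkpoint in save_checkpoint.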