Loading megatron/checkpointing.py +4 −6 Original line number Diff line number Diff line Loading @@ -154,13 +154,11 @@ def get_rng_state(): if torch.distributed.is_initialized() and \ mpu.get_data_parallel_world_size() > 1 and \ args.data_parallel_random_init: if mpu.get_data_parallel_rank() == 0: rng_state_list = \ [None for i in range(mpu.get_data_parallel_world_size())] torch.distributed.gather_object( rng_state, torch.distributed.all_gather_object( rng_state_list, dst=mpu.get_data_parallel_src_rank(), rng_state, group=mpu.get_data_parallel_group()) else: rng_state_list = [rng_state] Loading Loading
megatron/checkpointing.py +4 −6 Original line number Diff line number Diff line Loading @@ -154,13 +154,11 @@ def get_rng_state(): if torch.distributed.is_initialized() and \ mpu.get_data_parallel_world_size() > 1 and \ args.data_parallel_random_init: if mpu.get_data_parallel_rank() == 0: rng_state_list = \ [None for i in range(mpu.get_data_parallel_world_size())] torch.distributed.gather_object( rng_state, torch.distributed.all_gather_object( rng_state_list, dst=mpu.get_data_parallel_src_rank(), rng_state, group=mpu.get_data_parallel_group()) else: rng_state_list = [rng_state] Loading