Commit b037a69e authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

using all_gather instead of gather (nccl does not support gather)

parent a7ee77ea
Loading
Loading
Loading
Loading
+4 −6
Original line number Diff line number Diff line
@@ -154,13 +154,11 @@ def get_rng_state():
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        if mpu.get_data_parallel_rank() == 0:
        rng_state_list = \
            [None for i in range(mpu.get_data_parallel_world_size())]
        torch.distributed.gather_object(
            rng_state,
        torch.distributed.all_gather_object(
            rng_state_list,
            dst=mpu.get_data_parallel_src_rank(),
            rng_state,
            group=mpu.get_data_parallel_group())
    else:
        rng_state_list = [rng_state]