Commit 65ed158f authored by Wang, Xiao

Streamline the code for distributed GPU initialization

parent 774121f9
+11 −18
@@ -33,7 +33,6 @@ from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)

import os
-from mpi4py import MPI
import subprocess

def initialize_megatron(extra_args_provider=None, args_defaults={},
@@ -164,6 +163,7 @@ def _initialize_distributed():
    args = get_args()

    device_count = torch.cuda.device_count()
+
    if torch.distributed.is_initialized():

        if args.rank == 0:
@@ -187,28 +187,21 @@ def _initialize_distributed():
            torch.cuda.set_device(device)
    
    # Call the init process
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    world_size = comm.Get_size()
-    master_addr = None
-    if rank == 0:
-        hostname_cmd = ["hostname -I"]
-        result = subprocess.check_output(hostname_cmd, shell=True)
-        master_addr = result.decode('utf-8').split()[0]
-    master_addr = comm.bcast(master_addr, root=0)
-    proc_name = MPI.Get_processor_name()
-    all_procs = comm.allgather(proc_name)
-    local_rank = sum([i == proc_name for i in all_procs[:rank]])
+    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

+    get_master = "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)".format(os.environ['LSB_DJOB_HOSTFILE'])
+    master_addr=str(subprocess.check_output(get_master, shell=True))[2:-3]

    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(29500)
-    init_method=None
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        timeout=timedelta(minutes=10),
-        init_method=init_method)

+    torch.distributed.init_process_group( backend=args.distributed_backend,rank=rank,world_size=world_size)


    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
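
For reference, below is a minimal, self-contained sketch of the rendezvous path this commit switches to, assuming an OpenMPI launcher (which exports the OMPI_COMM_WORLD_* variables) and an LSF job that provides LSB_DJOB_HOSTFILE. The function name, the "nccl" default, and the port number are illustrative and not taken from the patch.

# Sketch only: assumes OpenMPI (OMPI_COMM_WORLD_*) and LSF (LSB_DJOB_HOSTFILE).
import os
import subprocess

import torch


def init_distributed(backend="nccl", port=29500):
    # Per-process identity comes straight from the OpenMPI environment,
    # so no mpi4py import or MPI broadcast is needed.
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])

    # The first compute host in the LSF hostfile acts as the rendezvous master;
    # batch and login nodes are filtered out, mirroring the patched code.
    get_master = ("cat {} | sort | uniq | grep -v batch | grep -v login | head -1"
                  .format(os.environ["LSB_DJOB_HOSTFILE"]))
    master_addr = subprocess.check_output(get_master, shell=True).decode("utf-8").strip()

    # The default env:// rendezvous of init_process_group() reads these
    # variables; rank and world_size are also passed explicitly below.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)

    # Bind this process to its GPU before creating the process group.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size)
    return rank, world_size, local_rank

Compared with the removed mpi4py path, nothing has to be gathered or broadcast over MPI: the launcher's environment already carries rank, world size, and local rank, and the hostfile gives every process the same master address.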