Commit 6172f198 authored by Wang, Xiao

Rewrote the comments for the get_data_parallel_src_rank function: its result identifies which data parallel group the GPU belongs to.
parent 65ed158f
+28 −2
@@ -106,10 +106,20 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)


    print("XW: world_size ",world_size, " tensor_model_size ",tensor_model_parallel_size, " pipeline_model_size ",pipeline_model_parallel_size," data_parallel_size ",data_parallel_size,flush=True)


    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size



    print("XW: num_tensor_groups ",num_tensor_model_parallel_groups, " num_pipeline_groups ",num_pipeline_model_parallel_groups, " num_data_groups ",num_data_parallel_groups,flush=True)



    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
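For concreteness, here is a minimal standalone sketch of the size arithmetic in this hunk, using hypothetical values (16 GPUs, tensor parallel 2, pipeline parallel 2; the numbers are illustrative, not taken from this commit):

    world_size = 16                     # hypothetical: 16 GPUs in total
    tensor_model_parallel_size = 2      # hypothetical
    pipeline_model_parallel_size = 2    # hypothetical

    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)                # 4
    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size      # 8
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 8
    num_data_parallel_groups = world_size // data_parallel_size                      # 4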
@@ -122,6 +132,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,

    rank = torch.distributed.get_rank()

    print("XW: rank ",rank,flush=True)


    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
@@ -133,11 +146,17 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)

            print("XW data_parrallel: i ",i," j ",j," ranks ",ranks)

            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)

            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

    print("XW: all_data_parallel_group_ranks ",all_data_parallel_group_ranks,flush=True)

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
@@ -145,6 +164,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]

        print("XW model_parallel: i ",i," ranks ",ranks)

        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group
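With the same hypothetical group table, the model-parallel groups are the transpose of the data-parallel ones: group i collects the i-th rank from every data-parallel group, so each group spans one complete model replica.

    all_data_parallel_group_ranks = [[0, 2, 4, 6], [1, 3, 5, 7],
                                     [8, 10, 12, 14], [9, 11, 13, 15]]
    for i in range(4):                  # data_parallel_size = 4
        print(i, [group[i] for group in all_data_parallel_group_ranks])
    # 0 [0, 1, 8, 9]
    # 1 [2, 3, 10, 11]
    # 2 [4, 5, 12, 13]
    # 3 [6, 7, 14, 15]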
@@ -156,6 +178,11 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)


        print("XW tensor_parrallel: i ",i," ranks [",*ranks,"]")  


        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
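The tensor-parallel groups, by contrast, are contiguous chunks of tensor_model_parallel_size ranks. Continuing the hypothetical layout:

    tensor_model_parallel_size = 2
    for i in range(8):                  # num_tensor_model_parallel_groups = 8
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        print(i, list(ranks))
    # 0 [0, 1]
    # 1 [2, 3]
    #   ...
    # 7 [14, 15]

Keeping tensor-parallel ranks adjacent tends to place the most communication-heavy collectives on GPUs within the same node.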
@@ -453,8 +480,7 @@ def get_tensor_model_parallel_src_rank():


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    """XW: calculate which data parallel group the GPU belongs to"""
    global_rank = torch.distributed.get_rank()
    data_parallel_size = get_data_parallel_world_size()
    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
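A worked example of the modulo that presumably follows (the return statement, global_rank % num_data_parallel_groups in upstream Megatron-LM, falls outside this hunk). Take a hypothetical 16-GPU layout with tensor_model_parallel_size = 2 and a single pipeline stage, so the data-parallel groups are the even ranks (group 0) and the odd ranks (group 1), and num_data_parallel_groups = 2:

    num_data_parallel_groups = 2        # hypothetical: 16 GPUs, tp = 2, pp = 1
    for global_rank in (0, 1, 10, 15):
        print(global_rank, "->", global_rank % num_data_parallel_groups)
    # 0 -> 0, 1 -> 1, 10 -> 0, 15 -> 1

In this single-pipeline-stage case the result is both the index of the GPU's data-parallel group and that group's first global rank, which is what the rewritten docstring describes.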