Commit 393249bc authored by Wang, Xiao

add comments for embedding ranks and position embedding ranks

parent 6172f198
+4 −20
@@ -107,19 +107,11 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
                                         pipeline_model_parallel_size)
 
-
-    print("XW: world_size ",world_size, " tensor_model_size ",tensor_model_parallel_size, " pipeline_model_size ",pipeline_model_parallel_size," data_parallel_size ",data_parallel_size,flush=True)
-
 
     num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size
 
-
-
-    print("XW: num_tensor_groups ",num_tensor_model_parallel_groups, " num_pipeline_groups ",num_pipeline_model_parallel_groups, " num_data_groups ",num_data_parallel_groups,flush=True)
-
-
 
     if virtual_pipeline_model_parallel_size_ is not None:
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
@@ -132,8 +124,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
 
     rank = torch.distributed.get_rank()
 
-    print("XW: rank ",rank,flush=True)
-
 
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
@@ -147,15 +137,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             ranks = range(start_rank + j, end_rank,
                           tensor_model_parallel_size)
 
-            print("XW data_parrallel: i ",i," j ",j," ranks ",ranks)
-
             all_data_parallel_group_ranks.append(list(ranks))
             group = torch.distributed.new_group(ranks)
 
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
 
-    print("XW: all_data_parallel_group_ranks ",all_data_parallel_group_ranks,flush=True)
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
@@ -165,7 +152,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         ranks = [data_parallel_group_ranks[i]
                  for data_parallel_group_ranks in all_data_parallel_group_ranks]
 
-        print("XW model_parallel: i ",i," ranks ",ranks)
 
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
@@ -179,10 +165,6 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
 
-
-        print("XW tensor_parrallel: i ",i," ranks [",*ranks,"]")
-
-
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
@@ -205,14 +187,16 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         ranks = range(i, world_size,
                       num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks)
+
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
-            embedding_ranks = [ranks[0], ranks[-1]]
-            position_embedding_ranks = [ranks[0]]
+            embedding_ranks = [ranks[0], ranks[-1]]   # For GPUs in the same pipeline model parallel group, the embedding group is the two GPUs holding the first and the last sub-model (pipeline stage)
+            position_embedding_ranks = [ranks[0]]   # Only the GPU assigned the first pipeline sub-model
+
             if pipeline_model_parallel_split_rank_ is not None:
                 if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                     embedding_ranks = [ranks[0],
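
For context, the grouping arithmetic that the deleted XW debug prints were inspecting can be reproduced offline. Below is a minimal, torch-free sketch of that enumeration. The helper name `enumerate_groups` and the example sizes (world_size=16, tensor=2, pipeline=4, hence data-parallel=2) are illustrative and not part of this commit; only the rank arithmetic mirrors the code above, including the `embedding_ranks` / `position_embedding_ranks` selection that the new comments describe.

```python
# Standalone sketch (not part of the commit): reproduces the rank
# enumeration from initialize_model_parallel() without torch, so the
# group layout the XW prints reported can be inspected offline.
# The sizes below are illustrative defaults.

def enumerate_groups(world_size=16,
                     tensor_model_parallel_size=2,
                     pipeline_model_parallel_size=4):
    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

    # Data-parallel groups: ranks with the same tensor/pipeline
    # coordinates, strided across the data-parallel replicas.
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))

    # Tensor-model-parallel groups: blocks of consecutive ranks.
    tensor_groups = [list(range(i * tensor_model_parallel_size,
                                (i + 1) * tensor_model_parallel_size))
                     for i in range(num_tensor_model_parallel_groups)]

    # Pipeline-model-parallel groups: ranks strided by the number of
    # pipeline groups. The embedding group couples the first and last
    # stage of each pipeline; position embeddings live only on the
    # first stage.
    pipeline_groups, embedding_groups, position_embedding_groups = [], [], []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        pipeline_groups.append(ranks)
        if len(ranks) > 1:
            embedding_groups.append([ranks[0], ranks[-1]])
            position_embedding_groups.append([ranks[0]])

    return (all_data_parallel_group_ranks, tensor_groups,
            pipeline_groups, embedding_groups, position_embedding_groups)


if __name__ == "__main__":
    data, tensor, pipe, emb, pos = enumerate_groups()
    print("data_parallel:     ", data)    # [[0, 2], [1, 3], [4, 6], ...]
    print("tensor_parallel:   ", tensor)  # [[0, 1], [2, 3], ..., [14, 15]]
    print("pipeline_parallel: ", pipe)    # [[0, 4, 8, 12], [1, 5, 9, 13], ...]
    print("embedding:         ", emb)     # first and last stage per pipeline
    print("position_embedding:", pos)     # first stage only
```

Running the sketch with these defaults shows, for example, pipeline group [0, 4, 8, 12] yielding embedding_ranks [0, 12] and position_embedding_ranks [0], which is exactly the pairing the new comments document.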