Loading megatron/mpu/initialize.py +28 −2 Original line number Diff line number Diff line Loading @@ -106,10 +106,20 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) print("XW: world_size ",world_size, " tensor_model_size ",tensor_model_parallel_size, " pipeline_model_size ",pipeline_model_parallel_size," data_parallel_size ",data_parallel_size,flush=True) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size num_data_parallel_groups = world_size // data_parallel_size print("XW: num_tensor_groups ",num_tensor_model_parallel_groups, " num_pipeline_groups ",num_pipeline_model_parallel_groups, " num_data_groups ",num_data_parallel_groups,flush=True) if virtual_pipeline_model_parallel_size_ is not None: global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE Loading @@ -122,6 +132,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, rank = torch.distributed.get_rank() print("XW: rank ",rank,flush=True) # Build the data-parallel groups. global _DATA_PARALLEL_GROUP assert _DATA_PARALLEL_GROUP is None, \ Loading @@ -133,11 +146,17 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for j in range(tensor_model_parallel_size): ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) print("XW data_parrallel: i ",i," j ",j," ranks ",ranks) all_data_parallel_group_ranks.append(list(ranks)) group = torch.distributed.new_group(ranks) if rank in ranks: _DATA_PARALLEL_GROUP = group print("XW: all_data_parallel_group_ranks ",all_data_parallel_group_ranks,flush=True) # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is None, \ Loading @@ -145,6 +164,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for i in range(data_parallel_size): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] print("XW model_parallel: i ",i," ranks ",ranks) group = torch.distributed.new_group(ranks) if rank in ranks: _MODEL_PARALLEL_GROUP = group Loading @@ -156,6 +178,11 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) print("XW tensor_parrallel: i ",i," ranks [",*ranks,"]") group = torch.distributed.new_group(ranks) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group Loading Loading @@ -453,8 +480,7 @@ def get_tensor_model_parallel_src_rank(): def get_data_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" """XW: calculate which data parallel group the GPU belongs to""" global_rank = torch.distributed.get_rank() data_parallel_size = get_data_parallel_world_size() num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size Loading Loading
megatron/mpu/initialize.py +28 −2 Original line number Diff line number Diff line Loading @@ -106,10 +106,20 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) print("XW: world_size ",world_size, " tensor_model_size ",tensor_model_parallel_size, " pipeline_model_size ",pipeline_model_parallel_size," data_parallel_size ",data_parallel_size,flush=True) num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size num_data_parallel_groups = world_size // data_parallel_size print("XW: num_tensor_groups ",num_tensor_model_parallel_groups, " num_pipeline_groups ",num_pipeline_model_parallel_groups, " num_data_groups ",num_data_parallel_groups,flush=True) if virtual_pipeline_model_parallel_size_ is not None: global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE Loading @@ -122,6 +132,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, rank = torch.distributed.get_rank() print("XW: rank ",rank,flush=True) # Build the data-parallel groups. global _DATA_PARALLEL_GROUP assert _DATA_PARALLEL_GROUP is None, \ Loading @@ -133,11 +146,17 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for j in range(tensor_model_parallel_size): ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) print("XW data_parrallel: i ",i," j ",j," ranks ",ranks) all_data_parallel_group_ranks.append(list(ranks)) group = torch.distributed.new_group(ranks) if rank in ranks: _DATA_PARALLEL_GROUP = group print("XW: all_data_parallel_group_ranks ",all_data_parallel_group_ranks,flush=True) # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is None, \ Loading @@ -145,6 +164,9 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for i in range(data_parallel_size): ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks] print("XW model_parallel: i ",i," ranks ",ranks) group = torch.distributed.new_group(ranks) if rank in ranks: _MODEL_PARALLEL_GROUP = group Loading @@ -156,6 +178,11 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) print("XW tensor_parrallel: i ",i," ranks [",*ranks,"]") group = torch.distributed.new_group(ranks) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group Loading Loading @@ -453,8 +480,7 @@ def get_tensor_model_parallel_src_rank(): def get_data_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" """XW: calculate which data parallel group the GPU belongs to""" global_rank = torch.distributed.get_rank() data_parallel_size = get_data_parallel_world_size() num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size Loading