Loading megatron/mpu/layers.py +1 −1 Original line number Diff line number Diff line Loading @@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, per_partition_per_stride_size = divide(per_partition_size, stride) weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() my_weight_list = weight_list[rank::world_size] Loading Loading
megatron/mpu/layers.py +1 −1 Original line number Diff line number Diff line Loading @@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, per_partition_per_stride_size = divide(per_partition_size, stride) weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() my_weight_list = weight_list[rank::world_size] Loading