Commit 8aa4619f authored by Jared Casper's avatar Jared Casper
Browse files

Cherry-pick fix from development repo.

parent 1b3dfa2f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_model_parallel_rank()
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]