Commit 04cd82bd authored by mohammad
Browse files

debugged the cpu init for mp=1

parent 42d2be06
Loading
Loading
Loading
Loading
+1 −8
Original line number Diff line number Diff line
@@ -72,14 +72,6 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
    weight.partition_dim = partition_dim
    weight.partition_stride = stride

    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = get_model_parallel_world_size()
    if world_size == 1:
        init_method(weight)
        if return_master_weight:
            return weight
        return None

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
@@ -93,6 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():