Commit 930ec4a2 authored by Jared Casper
Browse files

Merge branch 'fix_cpu_init' into 'main'

Fix cpu initialization for mp=1 and fp16

See merge request ADLR/megatron-lm!123
parents 42d2be06 04cd82bd
Loading
Loading
Loading
Loading
+1 −8
Original line number Diff line number Diff line
@@ -72,14 +72,6 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
    weight.partition_dim = partition_dim
    weight.partition_stride = stride

    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = get_model_parallel_world_size()
    if world_size == 1:
        init_method(weight)
        if return_master_weight:
            return weight
        return None

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
@@ -93,6 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():