Commit e10760ea authored by Boris Fomitchev's avatar Boris Fomitchev
Browse files

Make --lazy_mpu_init act like _USE_CPU_INITIALIZATION

parent d4bb6cb0
Loading
Loading
Loading
Loading
+5 −5
Original line number | Diff line number | Diff line
@@ -141,7 +141,7 @@ class VocabParallelEmbedding(torch.nn.Module):

        # Allocate weights and initialize.
        args = get_args()
        if _USE_CPU_INITIALIZATION:
        if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
@@ -217,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if _USE_CPU_INITIALIZATION:
        if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
@@ -233,7 +233,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                          partition_dim=0, stride=stride)
            
        if bias:
            if _USE_CPU_INITIALIZATION:
            if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
@@ -311,7 +311,7 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if _USE_CPU_INITIALIZATION:
        if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
@@ -326,7 +326,7 @@ class RowParallelLinear(torch.nn.Module):
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)
        if bias:
            if _USE_CPU_INITIALIZATION:
            if _USE_CPU_INITIALIZATION or args.lazy_mpu_init:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else: