megatron/mpu/layers.py (+1 −8)

@@ -72,14 +72,6 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     weight.partition_dim = partition_dim
     weight.partition_stride = stride
 
-    # If we only use 1 process for model parallelism, bypass scatter.
-    world_size = get_model_parallel_world_size()
-    if world_size == 1:
-        init_method(weight)
-        if return_master_weight:
-            return weight
-        return None
-
     # Initialize master weight
     master_weight = torch.empty(output_size, input_size,
                                 dtype=torch.float,
@@ -93,6 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     weight_list = torch.split(master_weight, per_partition_per_stride_size,
                               dim=partition_dim)
     rank = get_model_parallel_rank()
+    world_size = get_model_parallel_world_size()
     my_weight_list = weight_list[rank::world_size]
 
     with torch.no_grad():
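This change removes the single-process shortcut, so the full master weight is always materialized and split even when the model-parallel world size is 1; presumably this keeps CPU weight initialization bit-identical across parallel configurations, since the RNG is consumed the same way regardless of world size. One inferable side effect: with `return_master_weight=True` and a world size of 1, the function previously returned the partitioned `weight` itself, whereas it now returns the full `master_weight`. Below is a minimal sketch of the function after the patch, reconstructed from the hunk context above; the lines outside the shown hunks (the `divide` call, the dtype cast, and the final `torch.cat` into `weight`) are assumptions based on the surrounding Megatron-LM code, and the mpu helpers are stubbed so the sketch runs standalone.

```python
import torch


# Stubs standing in for megatron.mpu helpers so the sketch runs standalone;
# in the real module these query the model-parallel process group.
def get_model_parallel_rank():
    return 0


def get_model_parallel_world_size():
    return 1


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Build the full master weight on CPU, then keep this rank's chunks."""
    weight.model_parallel = True
    weight.partition_dim = partition_dim
    weight.partition_stride = stride

    # Always initialize the full (unpartitioned) master weight in fp32,
    # so initialization does not depend on the model-parallel size.
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=weight.dtype)

    # Split along the partition dimension and pick this rank's chunks.
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()  # moved here by this diff
    my_weight_list = weight_list[rank::world_size]

    # Copy this rank's chunks into the partitioned parameter.
    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None


if __name__ == "__main__":
    # Quick check: initialize a 4x3 weight partitioned along dim 0,
    # with the stubbed world size of 1 (the formerly bypassed case).
    w = torch.empty(4, 3)
    _initialize_affine_weight_cpu(w, 4, 3, per_partition_size=4,
                                  partition_dim=0,
                                  init_method=torch.nn.init.xavier_normal_)
    print(w.shape)  # torch.Size([4, 3])
```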