megatron/arguments.py +5 −2

@@ -169,7 +169,7 @@ def _add_network_size_args(parser):
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
                        'reasons.')
-    group.add_argument('--onnx-safe', action='store_true',
+    group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with Torch ONNX exporter')

     return parser

@@ -335,8 +335,11 @@ def _add_distributed_args(parser):
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
                        help='If set to True, initialize_megatron() skips DDP initialization'
-                       ' and returns function to complete it instead'
+                       ' and returns function to complete it instead.'
+                       'Also turns on --use-cpu-initialization flag.'
+                       'This is for external DDP manager.'
                        )
+    group.add_argument('--use-cpu-initialization', action='store_true',
+                       help='If set, affine parallel weights initialization uses CPU' )
     return parser
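A note on these options: argparse's type=bool simply calls bool() on the raw command-line string, so any non-empty value (even the literal "False") parses as True, and leaving the option off yields None. That is presumably why --lazy-mpu-init and the revised --onnx-safe are intended to be set programmatically (for example through the args_defaults parameter visible in the initialize_megatron diff below) rather than typed on the command line, while --use-cpu-initialization stays a plain store_true flag. A minimal sketch of that behaviour with a stand-alone parser (hypothetical, not the real _add_distributed_args):

import argparse

# Hypothetical parser that mirrors only the options touched by this diff.
parser = argparse.ArgumentParser()
parser.add_argument('--onnx-safe', type=bool, required=False,
                    help='Use workarounds for known problems with Torch ONNX exporter')
parser.add_argument('--lazy-mpu-init', type=bool, required=False,
                    help='Skip DDP init and return a function that completes it later')
parser.add_argument('--use-cpu-initialization', action='store_true',
                    help='Allocate and initialize affine parallel weights on CPU')

# bool('False') is True, so type=bool cannot express an explicit "off" value:
print(parser.parse_args(['--lazy-mpu-init', 'False']).lazy_mpu_init)           # True
print(parser.parse_args([]).lazy_mpu_init)                                     # None (unset)
print(parser.parse_args(['--use-cpu-initialization']).use_cpu_initialization)  # True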
megatron/initialize.py +1 −0

@@ -62,6 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     args = get_args()
     if args.lazy_mpu_init:
+        args.use_cpu_initialization=True
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_model_parallel_world_size(args.model_parallel_size)
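Forcing use_cpu_initialization on here matches the new help text for --lazy-mpu-init: when an external DDP manager will finish initialization later, the parallel layers should allocate their weights on CPU first. Going by that help text alone ("skips DDP initialization and returns function to complete it instead"), the calling pattern would look roughly like the sketch below; the return value and its handling are assumptions drawn from that description, not from code shown in this diff.

from megatron.initialize import initialize_megatron

# Hedged sketch of an external DDP manager driving the lazy path. Per the
# new --lazy-mpu-init help text, initialize_megatron() skips its own DDP
# setup and hands back a callable that completes it; the exact return
# semantics are assumed here.
continuation = initialize_megatron(args_defaults={'lazy_mpu_init': True})

# ... the external framework sets up its own process groups and device placement ...

if callable(continuation):
    continuation()  # finish the deferred model-parallel initialization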
megatron/mpu/layers.py +5 −9

@@ -47,10 +47,6 @@
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args
-
-_USE_CPU_INITIALIZATION = False
-
-
 def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
     """Initialize affine weight for model parallel on GPU."""

@@ -141,7 +137,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Allocate weights and initialize.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(
                 self.num_embeddings_per_partition, self.embedding_dim,
                 dtype=args.params_dtype))

@@ -217,7 +213,7 @@ class ColumnParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                 self.input_size,
                                                 dtype=args.params_dtype))

@@ -233,7 +229,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                           partition_dim=0, stride=stride)

         if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(
                     self.output_size_per_partition, dtype=args.params_dtype))
             else:

@@ -311,7 +307,7 @@ class RowParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
             self.weight = Parameter(torch.empty(self.output_size,
                                                 self.input_size_per_partition,
                                                 dtype=args.params_dtype))

@@ -326,7 +322,7 @@ class RowParallelLinear(torch.nn.Module):
         _initialize_affine_weight_gpu(self.weight, init_method,
                                       partition_dim=1, stride=stride)
         if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(self.output_size,
                                                   dtype=args.params_dtype))
             else:
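The layer changes are mechanical: the hard-coded module-level switch _USE_CPU_INITIALIZATION (always False) is removed, and each branch now consults the runtime flag args.use_cpu_initialization. What the flag ultimately selects is where each parameter tensor is first allocated. Below is a self-contained illustrative sketch of that pattern, under the assumption that the GPU branch allocates directly on the current CUDA device; the else branches are not shown in this diff, so the details there are assumptions.

import torch
from torch.nn.parameter import Parameter

def allocate_weight(rows, cols, use_cpu_initialization, params_dtype=torch.float32):
    """Illustrative allocation pattern only.

    In the real layers the CPU branch builds a master weight on the host,
    runs the init_method there, and copies this rank's partition, while the
    GPU branch allocates the partition directly on the current device.
    """
    if use_cpu_initialization:
        # CPU allocation: lets an external DDP manager decide device placement later.
        weight = Parameter(torch.empty(rows, cols, dtype=params_dtype))
    else:
        # Direct GPU allocation (assumed device handling; requires CUDA).
        weight = Parameter(torch.empty(rows, cols, dtype=params_dtype,
                                       device=torch.cuda.current_device()))
    torch.nn.init.xavier_uniform_(weight)  # stand-in for the layer's init_method
    return weight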