Commit 42d2be06 authored by Mohammad Shoeybi

Merge branch 'onnx-safe' into 'main'

--onnx-safe made overridable; --lazy-mpu-init now also turns on --use-cpu-initialization, which replaces the _USE_CPU_INITIALIZATION constant

See merge request ADLR/megatron-lm!117
parents 3d1cbecf b6b48a32
megatron/arguments.py +5 −2
@@ -169,7 +169,7 @@ def _add_network_size_args(parser):
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
-    group.add_argument('--onnx-safe', action='store_true',
+    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with Torch ONNX exporter')

    return parser
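
A standalone sketch of why this change matters (plain argparse, not Megatron
code): with action='store_true' an absent flag always parses to False, so a
wrapper cannot tell "unset" from "disabled"; with type=bool, required=False an
absent flag parses to None and downstream code can supply its own default. The
usual argparse caveat applies: type=bool treats any non-empty string as True.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--onnx-safe', type=bool, required=False,
                    help='Use workarounds for known problems with Torch ONNX exporter')

args = parser.parse_args([])
assert args.onnx_safe is None      # absent flag is distinguishable from False

args = parser.parse_args(['--onnx-safe', 'True'])
assert args.onnx_safe is True      # any non-empty string parses to True
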
@@ -335,8 +335,11 @@ def _add_distributed_args(parser):
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() skips DDP initialization'
-                      ' and returns function to complete it instead'
+                      ' and returns a function to complete it instead. '
+                      'Also turns on --use-cpu-initialization flag. '
                       'This is for external DDP manager.')
+    group.add_argument('--use-cpu-initialization', action='store_true',
+                       help='If set, affine parallel weights initialization uses CPU')
    return parser
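
A hypothetical caller-side sketch of the pattern these two flags enable (only
initialize_megatron and its args_defaults parameter appear in this merge; the
import path and the handling of the returned callable are assumptions based on
the help text above):

from megatron.initialize import initialize_megatron

# With lazy_mpu_init set, initialize_megatron() skips DDP setup, turns on
# CPU weight initialization, and hands back a callable that finishes the
# deferred initialization.
finish_mpu_init = initialize_megatron(args_defaults={'lazy_mpu_init': True})

# ... the external DDP manager creates its own process groups here ...

finish_mpu_init()   # complete Megatron's deferred DDP/mpu initialization
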


megatron/initialize.py +1 −0
@@ -62,6 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},

    args = get_args()
    if args.lazy_mpu_init:
+        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_model_parallel_world_size(args.model_parallel_size)
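
Condensed, the lazy branch behaves roughly as below (a sketch: the eager path
and the name _finish_mpu_init are assumptions inferred from the
--lazy-mpu-init help text, not code from this merge):

if args.lazy_mpu_init:
    args.use_cpu_initialization = True
    # Only basic model-parallel globals are set now; the expensive DDP/mpu
    # setup is packaged into a callable and handed back to the caller.
    set_model_parallel_world_size(args.model_parallel_size)
    return _finish_mpu_init   # hypothetical name for the deferred step
_finish_mpu_init()            # eager path: initialize immediately
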
megatron/mpu/layers.py +5 −9
@@ -47,10 +47,6 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args


-_USE_CPU_INITIALIZATION = False


def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""
@@ -141,7 +137,7 @@ class VocabParallelEmbedding(torch.nn.Module):

        # Allocate weights and initialize.
        args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
@@ -217,7 +213,7 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
@@ -233,7 +229,7 @@ class ColumnParallelLinear(torch.nn.Module):
                                          partition_dim=0, stride=stride)
            
        if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
@@ -311,7 +307,7 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
-        if _USE_CPU_INITIALIZATION:
+        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
@@ -326,7 +322,7 @@ class RowParallelLinear(torch.nn.Module):
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)
        if bias:
-            if _USE_CPU_INITIALIZATION:
+            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
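
The hunks in this file remove the module-level _USE_CPU_INITIALIZATION
constant and replace every use of it with the runtime
args.use_cpu_initialization flag, so the CPU path can be toggled per run
(directly, or implicitly via --lazy-mpu-init). A minimal self-contained sketch
of the allocation pattern (function name hypothetical, not Megatron code):

import torch
from torch.nn import Parameter

def allocate_weight(out_features, in_features, dtype, use_cpu_init):
    """Sketch of the CPU-vs-GPU weight allocation pattern in the hunks above."""
    if use_cpu_init:
        # Host allocation: the full weight can be built and initialized on
        # CPU before each model-parallel rank keeps only its own partition.
        return Parameter(torch.empty(out_features, in_features, dtype=dtype))
    # Device allocation: the partition is created and initialized directly
    # on the current GPU.
    return Parameter(torch.empty(out_features, in_features, dtype=dtype,
                                 device=torch.cuda.current_device()))
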