Commit 0a74547c authored by Jared Casper

Merge branch 'parallel-nemo' into 'main'

Changes for NeMo/lightning compatibility

See merge request ADLR/megatron-lm!104
parents bdc52c19 4e8ace9b
+2 −2
@@ -30,7 +30,7 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers

from .initialize  import initialize_megatron

def print_rank_0(message):
    """If distributed is initialized print only on rank 0."""
+4 −1
@@ -322,7 +322,10 @@ def _add_distributed_args(parser):
                       'to use.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')

    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() skips DDP initialization'
                       ' and returns a function to complete it instead. '
                       'This is for an external DDP manager.')
    return parser
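
A minimal caller-side sketch of how an external DDP manager (e.g. a NeMo/Lightning trainer) might use this flag; the call site below is illustrative only and assumes the usual required Megatron arguments are supplied on the command line or via args_defaults:

from megatron.initialize import initialize_megatron

# Request lazy initialization: Megatron parses args and sets globals, but
# defers torch.distributed / model-parallel setup and returns a continuation.
finish_mpu_init = initialize_megatron(
    args_defaults={'lazy_mpu_init': True},
    ignore_unknown_args=True)

# ... the external framework sets up devices / its own DDP wrapper here ...

if finish_mpu_init is not None:
    # Complete Megatron's distributed init and seeding once the caller is ready.
    finish_mpu_init()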


+35 −26
@@ -26,7 +26,7 @@ from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables

from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size

def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
@@ -34,37 +34,51 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only 
    data processing. In general this arg should not be set unless you know 
    what you are doing."""
    what you are doing.
    Returns a function to finalize distributed env initialization 
    (optionally, only when args.lazy_mpu_init == True)

"""
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Temporary workaround (WAR) so simple cases like pytest calling this twice
    # with the same args keep working. A clean factory init is still needed.
    if mpu.model_parallel_is_initialized():
        return
    
    
    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()
        
    # Autoresume.
    _init_autoresume()

        # Random seeds for reproducibility.
    args = get_args()
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if  args.lazy_mpu_init:
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals    
        set_model_parallel_world_size(args.model_parallel_size)
        # and return function for external DDP manager to call when it has DDP initialized
        set_model_parallel_rank(args.rank)    
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()
        
        # Autoresume.
        _init_autoresume()
        
        # Write arguments to tensorboard.
        _write_args_to_tensorboard()
        # No continuation function
        return None
        

def _initialize_distributed():
@@ -79,11 +93,6 @@ def _initialize_distributed():
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
        if device_count > 0:
            device = torch.cuda.current_device()
            local_rank = args.rank % device_count
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'

    else:

+2 −1
@@ -15,6 +15,7 @@

from .distributed import *
from .bert_model import BertModel
from megatron.model.realm_model import ICTBertModel
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model
+8 −2
@@ -24,7 +24,7 @@ from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer

from megatron.model.utils import init_method_normal, scaled_init_method_normal

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
@@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,


def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       init_method, scaled_init_method):
                       init_method=None, scaled_init_method=None):
    """Build language model and return along with the key to save."""
    args = get_args()

@@ -55,6 +55,12 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
    elif args.onnx_safe:
        gelu = erf_gelu
    
    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        attention_mask_func=attention_mask_func,
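
A small sketch of what the relaxed get_language_model() signature allows; the mask function below is illustrative only, and the call assumes initialize_megatron() has already populated the global args (init_method_std, num_layers, etc.):

from megatron.model import get_language_model

def my_attention_mask_func(attention_scores, attention_mask):
    # Illustrative mask function; real models such as GPT2Model supply their own.
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

# init_method / scaled_init_method may now be omitted; they are rebuilt from
# args.init_method_std and args.num_layers inside get_language_model().
# Per its docstring, the function returns the model along with the key to save.
language_model, lm_key = get_language_model(
    attention_mask_func=my_attention_mask_func,
    num_tokentypes=0,
    add_pooler=False)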