Commit c632d205 authored by Boris Fomitchev

Addressing code review comments

parent be842037
+7 −7
@@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
-from .global_vars import set_global_variables
-from .mpu import set_model_parallel_rank, set_model_parallel_world_size
+from megatron.global_vars import set_global_variables
+from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size
 
 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -50,7 +50,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=ignore_unknown_args)
 
     # torch.distributed initialization
-    def ddp_init():
+    def finish_mpu_init():
         args = get_args()
         # Pytorch distributed.
         _initialize_distributed()
@@ -61,16 +61,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         _set_random_seed(args.seed)
 
     args = get_args()
-    if 'lazy_mpu_init' in args:
+    if  args.lazy_mpu_init:
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_model_parallel_world_size(args.model_parallel_size)
         # and return function for external DDP manager to call when it has DDP initialized
         set_model_parallel_rank(args.rank)
-        return ddp_init
+        return finish_mpu_init
     else:
-        # Megatron's own DDP. Do initialization right away
-        ddp_init()
+        # Megatron's MPU is the master. Complete initialization right away.
+        finish_mpu_init()
 
         # Autoresume.
         _init_autoresume()
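
For context on the renamed hook: with lazy_mpu_init set, initialize_megatron only records the model-parallel world size and rank, then returns finish_mpu_init for an external DDP manager to call once it has torch.distributed up. A minimal usage sketch under those assumptions; the megatron.initialize import path, the setup_distributed callback, and passing lazy_mpu_init via args_defaults are illustrative, not part of this commit:

# Sketch only, assuming initialize_megatron lives in megatron/initialize.py
# and that lazy_mpu_init can be supplied through args_defaults as in the
# signature shown in the diff above.
from megatron.initialize import initialize_megatron

def run_with_external_ddp(setup_distributed):
    # Defer Megatron's own torch.distributed / mpu initialization.
    finish_mpu_init = initialize_megatron(
        args_defaults={'lazy_mpu_init': True})

    # Hypothetical step: the caller brings up its own process group / DDP here.
    setup_distributed()

    # Once distributed state exists, let Megatron complete its MPU setup
    # (distributed init, random seeds, etc.) via the returned hook.
    finish_mpu_init()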