Loading megatron/arguments.py +4 −1 Original line number Diff line number Diff line Loading @@ -322,7 +322,10 @@ def _add_distributed_args(parser): 'to use.') group.add_argument('--local_rank', type=int, default=None, help='local rank passed from distributed launcher.') group.add_argument('--lazy-mpu-init', type=bool, required=False, help='If set to True, initialize_megatron() skips DDP initialization' ' and returns function to complete it instead' 'This is for external DDP manager.' ) return parser Loading megatron/initialize.py +36 −16 Original line number Diff line number Diff line Loading @@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume from megatron import get_args from megatron import get_tensorboard_writer from megatron import mpu from megatron.global_vars import set_global_variables from .global_vars import set_global_variables from .mpu import set_model_parallel_rank, set_model_parallel_world_size def initialize_megatron(extra_args_provider=None, args_defaults={}, ignore_unknown_args=False, allow_no_cuda=False): Loading @@ -34,7 +34,11 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, set autoresume and random seeds. `allow_no_cuda` should not be set unless using megatron for cpu only data processing. In general this arg should not be set unless you know what you are doing.""" what you are doing. Returns a function to finalize distributed env initialization (optionally, only for args.distributed_backend == "external_ddp") """ if not allow_no_cuda: # Make sure cuda is available. assert torch.cuda.is_available(), 'Megatron requires CUDA.' Loading @@ -45,20 +49,36 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, args_defaults=args_defaults, ignore_unknown_args=ignore_unknown_args) # torch.distributed initialization def ddp_init(): args = get_args() # Pytorch distributed. _initialize_distributed() # Autoresume. _init_autoresume() # Random seeds for reproducibility. args = get_args() if args.rank == 0: print('> setting random seeds to {} ...'.format(args.seed)) _set_random_seed(args.seed) args = get_args() if 'lazy_mpu_init' in args: # delayed initialization of DDP-related stuff # We only set basic DDP globals set_model_parallel_world_size(args.model_parallel_size) # and refurn function for external DDP manager to call when it has DDP initialized set_model_parallel_rank(args.rank) return ddp_init else: # Megatron's own DDP. Do initialization right away ddp_init() # Autoresume. _init_autoresume() # Write arguments to tensorboard. _write_args_to_tensorboard() # No continuation function return None def _initialize_distributed(): Loading megatron/model/__init__.py +2 −1 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ from .distributed import * from .bert_model import BertModel from megatron.model.realm_model import ICTBertModel from .realm_model import ICTBertModel from .gpt2_model import GPT2Model from .utils import get_params_for_weight_decay_optimization from .language_model import get_language_model megatron/model/language_model.py +8 −2 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ from megatron.module import MegatronModule from megatron.model.transformer import ParallelTransformer from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): Loading @@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, def get_language_model(attention_mask_func, num_tokentypes, add_pooler, init_method, scaled_init_method): init_method=None, scaled_init_method=None): """Build language model and return along with the key to save.""" args = get_args() Loading @@ -55,6 +55,12 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, elif args.onnx_safe: gelu = erf_gelu if init_method is None: init_method = init_method_normal(args.init_method_std) if scaled_init_method is None: scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) # Language model. language_model = TransformerLanguageModel( attention_mask_func=attention_mask_func, Loading megatron/mpu/__init__.py +2 −2 Original line number Diff line number Diff line Loading @@ -27,9 +27,9 @@ from .initialize import get_data_parallel_group from .initialize import get_data_parallel_rank from .initialize import get_data_parallel_world_size from .initialize import get_model_parallel_group from .initialize import get_model_parallel_rank from .initialize import get_model_parallel_rank, set_model_parallel_rank from .initialize import get_model_parallel_src_rank from .initialize import get_model_parallel_world_size from .initialize import get_model_parallel_world_size, set_model_parallel_world_size from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized Loading Loading
megatron/arguments.py +4 −1 Original line number Diff line number Diff line Loading @@ -322,7 +322,10 @@ def _add_distributed_args(parser): 'to use.') group.add_argument('--local_rank', type=int, default=None, help='local rank passed from distributed launcher.') group.add_argument('--lazy-mpu-init', type=bool, required=False, help='If set to True, initialize_megatron() skips DDP initialization' ' and returns function to complete it instead' 'This is for external DDP manager.' ) return parser Loading
megatron/initialize.py +36 −16 Original line number Diff line number Diff line Loading @@ -25,8 +25,8 @@ from megatron import get_adlr_autoresume from megatron import get_args from megatron import get_tensorboard_writer from megatron import mpu from megatron.global_vars import set_global_variables from .global_vars import set_global_variables from .mpu import set_model_parallel_rank, set_model_parallel_world_size def initialize_megatron(extra_args_provider=None, args_defaults={}, ignore_unknown_args=False, allow_no_cuda=False): Loading @@ -34,7 +34,11 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, set autoresume and random seeds. `allow_no_cuda` should not be set unless using megatron for cpu only data processing. In general this arg should not be set unless you know what you are doing.""" what you are doing. Returns a function to finalize distributed env initialization (optionally, only for args.distributed_backend == "external_ddp") """ if not allow_no_cuda: # Make sure cuda is available. assert torch.cuda.is_available(), 'Megatron requires CUDA.' Loading @@ -45,20 +49,36 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, args_defaults=args_defaults, ignore_unknown_args=ignore_unknown_args) # torch.distributed initialization def ddp_init(): args = get_args() # Pytorch distributed. _initialize_distributed() # Autoresume. _init_autoresume() # Random seeds for reproducibility. args = get_args() if args.rank == 0: print('> setting random seeds to {} ...'.format(args.seed)) _set_random_seed(args.seed) args = get_args() if 'lazy_mpu_init' in args: # delayed initialization of DDP-related stuff # We only set basic DDP globals set_model_parallel_world_size(args.model_parallel_size) # and refurn function for external DDP manager to call when it has DDP initialized set_model_parallel_rank(args.rank) return ddp_init else: # Megatron's own DDP. Do initialization right away ddp_init() # Autoresume. _init_autoresume() # Write arguments to tensorboard. _write_args_to_tensorboard() # No continuation function return None def _initialize_distributed(): Loading
megatron/model/__init__.py +2 −1 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ from .distributed import * from .bert_model import BertModel from megatron.model.realm_model import ICTBertModel from .realm_model import ICTBertModel from .gpt2_model import GPT2Model from .utils import get_params_for_weight_decay_optimization from .language_model import get_language_model
megatron/model/language_model.py +8 −2 Original line number Diff line number Diff line Loading @@ -24,7 +24,7 @@ from megatron.module import MegatronModule from megatron.model.transformer import ParallelTransformer from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): Loading @@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, def get_language_model(attention_mask_func, num_tokentypes, add_pooler, init_method, scaled_init_method): init_method=None, scaled_init_method=None): """Build language model and return along with the key to save.""" args = get_args() Loading @@ -55,6 +55,12 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, elif args.onnx_safe: gelu = erf_gelu if init_method is None: init_method = init_method_normal(args.init_method_std) if scaled_init_method is None: scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) # Language model. language_model = TransformerLanguageModel( attention_mask_func=attention_mask_func, Loading
megatron/mpu/__init__.py +2 −2 Original line number Diff line number Diff line Loading @@ -27,9 +27,9 @@ from .initialize import get_data_parallel_group from .initialize import get_data_parallel_rank from .initialize import get_data_parallel_world_size from .initialize import get_model_parallel_group from .initialize import get_model_parallel_rank from .initialize import get_model_parallel_rank, set_model_parallel_rank from .initialize import get_model_parallel_src_rank from .initialize import get_model_parallel_world_size from .initialize import get_model_parallel_world_size, set_model_parallel_world_size from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized Loading