Commit 97e07d64 authored by Mohammad
Browse files

Merge branch 'master' into remove_local_ddp_bcast

parents 41c1af0e eb0a8bf0
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -19,7 +19,8 @@ import argparse
import os


def parse_args(extra_args_provider=None, defaults={}):
def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments')

@@ -41,6 +42,9 @@ def parse_args(extra_args_provider=None, defaults={}):
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Set input defaults.
+8 −4
Original line number Diff line number Diff line
@@ -61,22 +61,26 @@ def get_timers():
    return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults=None,
                         ignore_unknown_args=False):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.

    Args:
        extra_args_provider: optional callable that receives the argparse
            parser so callers can register extra arguments before parsing.
        args_defaults: optional dict of argument-name -> default-value
            overrides forwarded to the parser. ``None`` means no overrides.
        ignore_unknown_args: if True, unrecognized command-line arguments
            are ignored instead of causing a parse error.
    """
    # Use None as the default instead of a literal {}: a mutable default
    # would be shared across calls and could silently leak state.
    if args_defaults is None:
        args_defaults = {}
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    # Tokenizer must be built from args; remaining globals are derived too.
    _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults=None,
                ignore_unknown_args=False):
    """Parse entire arguments and cache them in the module-level global.

    Args:
        extra_args_provider: optional callable that receives the argparse
            parser so callers can register extra arguments before parsing.
        defaults: optional dict of argument-name -> default-value overrides
            forwarded to ``parse_args``. ``None`` means no overrides.
        ignore_unknown_args: if True, unrecognized command-line arguments
            are ignored instead of causing a parse error.

    Returns:
        The parsed arguments namespace (also stored in ``_GLOBAL_ARGS``).

    Raises:
        Whatever ``_ensure_var_is_not_initialized`` raises if args were
        already parsed once — this function must be called at most once.
    """
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    # Avoid a shared mutable default ({}) — normalize None to a fresh dict.
    if defaults is None:
        defaults = {}
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


+4 −2
Original line number Diff line number Diff line
@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables


def initialize_megatron(extra_args_provider=None, args_defaults={}):
def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds."""
    # Make sure cuda is available.
@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # Pytorch distributed.
    _initialize_distributed()