Commit 28062e14 authored by mohammad

moved entire optimizer build and tested

parent fb218c9d
+0 −1
@@ -33,7 +33,6 @@ from .distributed import *
from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model

+0 −26
@@ -20,7 +20,6 @@ import math
import torch

from megatron import get_args
from megatron.model import import_layernorm

def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
@@ -60,28 +59,3 @@ def openai_gelu(x):
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))

def get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """

    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)
    
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
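The helper deleted here reappears unchanged in the optimizer module below. As a quick illustration of what it produces (a sketch, not code from this commit, with torch.nn.LayerNorm standing in for the class returned by import_layernorm), applying the same grouping to a small module yields two parameter groups, the second with weight decay forced to zero:

```python
import torch

# Toy module; torch.nn.LayerNorm stands in for Megatron's LayerNorm here.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)

weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
    if isinstance(module_, torch.nn.LayerNorm):
        # Every LayerNorm parameter (weight and bias) skips weight decay.
        no_weight_decay_params['params'].extend(
            p for p in module_._parameters.values() if p is not None)
    else:
        # Non-bias parameters keep the optimizer's default weight decay ...
        weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n != 'bias')
        # ... while biases join the no-decay group.
        no_weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n == 'bias')

# Two Linear weights decay; two Linear biases plus the LayerNorm weight and
# bias do not.
print(len(weight_decay_params['params']),
      len(no_weight_decay_params['params']))  # 2 4
```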
+34 −1
@@ -21,16 +21,49 @@ from abc import abstractmethod
import torch

from apex.multi_tensor_apply import multi_tensor_applier
from apex.optimizers import FusedAdam as Adam
import amp_C

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron.model import import_layernorm


def get_megatron_optimizer(optimizer, model):
def get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """

    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)
    
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params


def get_megatron_optimizer(model):

    args = get_args()

    # Base optimizer.
    param_groups = get_params_for_weight_decay_optimization(model)
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)

    if args.fp16:
        # Constant loss scale.
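The no-decay group above carries an explicit 'weight_decay': 0.0 entry because per-group options override the optimizer-level default that get_megatron_optimizer passes to Adam. A minimal demonstration of that override, using plain torch.optim.Adam instead of apex's FusedAdam so the sketch runs without apex installed:

```python
import torch

w = torch.nn.Parameter(torch.randn(4, 4))  # would sit in the decay group
b = torch.nn.Parameter(torch.zeros(4))     # would sit in the no-decay group

opt = torch.optim.Adam(
    [{'params': [w]},
     {'params': [b], 'weight_decay': 0.0}],
    lr=1e-3, weight_decay=0.01)

# The group-level setting wins over the optimizer default.
print([g['weight_decay'] for g in opt.param_groups])  # [0.01, 0.0]
```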
+10 −27
@@ -24,7 +24,6 @@ _TRAIN_START_TIME = time.time()

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron import get_timers
@@ -45,7 +44,6 @@ from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
@@ -184,6 +182,10 @@ def get_model(model_provider_func):
    # Build model on cpu.
    model = model_provider_func()

    # Set tensor model parallel attributes if not set.
    for param in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
@@ -212,30 +214,6 @@ def get_model(model_provider_func):
                              'Exiting.'.format(args.DDP_impl))


def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)

    # Wrap into fp16 optimizer.
    optimizer = get_megatron_optimizer(optimizer, model)
    return optimizer


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()
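For reference, the attribute-defaulting that the deleted get_optimizer above did by hand now lives in get_model, via the mpu.set_defaults_if_not_set_tensor_model_parallel_attributes loop added earlier in this file. A generic sketch of that "set an attribute only if it is missing" pattern on plain tensors; nothing here is Megatron API:

```python
import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
params[0].tensor_model_parallel = True  # pretend one param was tagged already

# Same shape as the removed loop: only set the attribute when it is absent.
for p in params:
    if not hasattr(p, 'tensor_model_parallel'):
        p.tensor_model_parallel = False

print([p.tensor_model_parallel for p in params])  # [True, False, False]
```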
@@ -284,7 +262,12 @@ def setup_model_and_optimizer(model_provider_func):
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)

    unwrapped_model = model
    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
        unwrapped_model = unwrapped_model.module
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
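Finally, a minimal sketch of the unwrapping loop used in setup_model_and_optimizer above, limited to torch's DDP wrapper since LocalDDP and FP16_Module are Megatron-internal classes; the unwrap_model helper name is illustrative, not part of the commit:

```python
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

def unwrap_model(model, wrappers=(torchDDP,)):
    """Peel off wrapper layers until the underlying module is reached."""
    # In Megatron the wrapper tuple also includes LocalDDP and FP16_Module.
    while isinstance(model, wrappers):
        model = model.module
    return model

plain = torch.nn.Linear(8, 8)
assert unwrap_model(plain) is plain  # an unwrapped module comes back unchanged
```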