Commit fb218c9d authored by mohammad's avatar mohammad
Browse files

megatron optimizer tested, before working on clip grad

parent 2eaa3ccc
Loading
Loading
Loading
Loading
+20 −3
Original line number Diff line number Diff line

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""

from abc import ABC
from abc import abstractmethod
@@ -70,6 +85,8 @@ class MegatronGradScaler(ABC):


class ConstantGradScaler(MegatronGradScaler):

    def update(self, found_inf):
        pass


+1 −48
Original line number Diff line number Diff line
@@ -39,7 +39,6 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
#from megatron.fp16 import FP16_Optimizer
from megatron.optimizer.optimizer import get_megatron_optimizer

from megatron.initialize import initialize_megatron
@@ -234,15 +233,6 @@ def get_optimizer(model):

    # Wrap into fp16 optimizer.
    optimizer = get_megatron_optimizer(optimizer, model)
    '''
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})
    '''
    return optimizer


@@ -373,13 +363,7 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    '''
    if args.fp16 and output_tensor_grad is None:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    '''

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
@@ -598,12 +582,6 @@ def train_step(forward_step_func, data_iterator,

    # Set grad to zero.
    optimizer.zero_grad()
    '''
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()
    '''

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
@@ -636,31 +614,6 @@ def train_step(forward_step_func, data_iterator,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update master gradients.
    '''
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()
    '''
    # Clipping gradients helps prevent the exploding gradient.
    '''
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            named_parameters = model.named_parameters()
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()
    '''

    # Update parameters.
    timers('optimizer').start()
    update_successfull = optimizer.step()