Commit 2eaa3ccc authored by mohammad

fp32 is also working

parent c6a58e41
+12 −16
@@ -105,11 +105,6 @@ def parse_args(extra_args_provider=None, defaults={},
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
 
-    # Fp16 loss scaling.
-    args.dynamic_loss_scale = False
-    if args.loss_scale is None:
-        args.dynamic_loss_scale = True
-
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
@@ -442,6 +437,18 @@ def _add_mixed_precision_args(parser):
 
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--loss-scale', type=float, default=None,
+                       help='Static loss scaling, positive power of 2 '
+                       'values can improve fp16 convergence. If None, dynamic '
+                       'loss scaling is used.')
+    group.add_argument('--initial-loss-scale', type=float, default=2**32,
+                       help='Initial loss scale for dynamic loss scaling.')
+    group.add_argument('--min-loss-scale', type=float, default=1.0,
+                       help='Minimum loss scale for dynamic loss scaling.')
+    group.add_argument('--loss-scale-window', type=float, default=1000,
+                       help='Window over which to raise/lower dynamic scale.')
+    group.add_argument('--hysteresis', type=int, default=2,
+                       help='Hysteresis for dynamic loss scaling.')
     group.add_argument('--fp32-residual-connection', action='store_true',
                        help='Move residual connections to fp32.')
     group.add_argument('--apply-query-key-layer-scaling', action='store_true',
@@ -452,21 +459,10 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-allreduce', action='store_true',
                        help='All-reduce in fp32')
-    group.add_argument('--hysteresis', type=int, default=2,
-                       help='hysteresis for dynamic loss scaling')
-    group.add_argument('--loss-scale', type=float, default=None,
-                       help='Static loss scaling, positive power of 2 '
-                       'values can improve fp16 convergence. If None, dynamic'
-                       'loss scaling is used.')
-    group.add_argument('--loss-scale-window', type=float, default=1000,
-                       help='Window over which to raise/lower dynamic scale.')
-    group.add_argument('--min-scale', type=float, default=1,
-                       help='Minimum loss scale for dynamic loss scale.')
-
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')
 
 
     return parser
 

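The relocated flags are consumed by the dynamic grad scaler built later in this commit. A minimal sketch of the intended pairing, with argument names taken from this diff (the surrounding glue is illustrative only):

    args = parse_args()
    if args.loss_scale is None:
        # Dynamic scaling: start at --initial-loss-scale, never drop below
        # --min-loss-scale, and adjust every --loss-scale-window steps.
        grad_scaler = DynamicGradScaler(
            initial_scale=args.initial_loss_scale,
            min_scale=args.min_loss_scale,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=args.loss_scale_window,
            hysteresis=args.hysteresis)
    else:
        # A fixed scale supplied via --loss-scale.
        grad_scaler = ConstantGradScaler(args.loss_scale)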
+3 −1
@@ -44,6 +44,8 @@ from .initialize import model_parallel_is_initialized
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
+from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes,
+                     copy_tensor_model_parallel_attributes)
 
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
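The two helpers exported here are defined in the layers diff below. A quick sketch of their use through the mpu namespace (tensors are hypothetical):

    import torch
    from megatron import mpu

    param = torch.empty(4, 4)
    # Fill in tensor_model_parallel / partition_dim / partition_stride
    # defaults only where they are missing.
    mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    clone = torch.empty_like(param)
    # Propagate whatever attributes param carries onto clone.
    mpu.copy_tensor_model_parallel_attributes(clone, param)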
+42 −7
@@ -37,13 +37,47 @@ from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args
 
+
+_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
+                                      'partition_dim': -1,
+                                      'partition_stride': 1}
+
+
+def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
+    # Make sure the attributes are not set.
+    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+        assert not hasattr(tensor, attribute)
+    # Set the attributes.
+    setattr(tensor, 'tensor_model_parallel', is_parallel)
+    setattr(tensor, 'partition_dim', dim)
+    setattr(tensor, 'partition_stride', stride)
+
+
+def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
+    def maybe_set(attribute, value):
+        if not hasattr(tensor, attribute):
+            setattr(tensor, attribute, value)
+    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
+
+
+def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
+    def maybe_copy(attribute):
+        if hasattr(source_tensor, attribute):
+            setattr(destination_tensor, attribute,
+                    getattr(source_tensor, attribute))
+    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
+        maybe_copy(attribute)
+
+
 def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
     """Initialize affine weight for model parallel on GPU."""
 
-    weight.tensor_model_parallel = True
-    weight.partition_dim = partition_dim
-    weight.partition_stride = stride
+    set_tensor_model_parallel_attributes(tensor=weight,
+                                         is_parallel=True,
+                                         dim=partition_dim,
+                                         stride=stride)
 
     with get_cuda_rng_tracker().fork():
         init_method(weight)
@@ -58,9 +92,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     Build the master weight on all processes and scatter
     the relevant chunk."""
 
-    weight.tensor_model_parallel = True
-    weight.partition_dim = partition_dim
-    weight.partition_stride = stride
+    set_tensor_model_parallel_attributes(tensor=weight,
+                                         is_parallel=True,
+                                         dim=partition_dim,
+                                         stride=stride)
 
     # Initialize master weight
     master_weight = torch.empty(output_size, input_size,
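Together the three helpers enforce a simple contract: set once (the asserts refuse to overwrite), default what is missing, copy what is present. A toy illustration, not part of the commit:

    import torch

    w = torch.empty(8, 8)
    set_tensor_model_parallel_attributes(tensor=w, is_parallel=True,
                                         dim=0, stride=1)
    # Calling it again would trip the asserts: the attributes now exist.

    b = torch.empty(8)
    set_defaults_if_not_set_tensor_model_parallel_attributes(b)
    assert b.tensor_model_parallel is False  # default
    assert b.partition_dim == -1             # default

    master = torch.empty(8, 8)
    copy_tensor_model_parallel_attributes(master, w)
    assert master.partition_dim == 0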
+84 −16
@@ -8,26 +8,34 @@ import torch
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 
-from megatron import mpu
 from megatron import get_args
+from megatron import get_timers
+from megatron import mpu
 
 
-def get_megatron_optimizer(optimizer):
+def get_megatron_optimizer(optimizer, model):
 
     args = get_args()
 
-    # Constant loss scale.
-    if args.loss_scale:
-        grad_scaler = ConstantGradScaler(args.loss_scale)
-    # Dynamic loss scale.
-    else:
-        grad_scaler = DynamicGradScaler(
-            initial_scale=2**32,
-            min_scale=args.min_scale,
-            growth_factor=2.0,
-            backoff_factor=0.5,
-            growth_interval=args.loss_scale_window,
-            hysteresis=args.hysteresis)
-
-    megatron_optimizer = FP16OptimizerWithFP16Params(
-        optimizer, grad_scaler, args.clip_grad)
-
-    return megatron_optimizer
+    if args.fp16:
+        # Constant loss scale.
+        if args.loss_scale:
+            grad_scaler = ConstantGradScaler(args.loss_scale)
+        # Dynamic loss scale.
+        else:
+            grad_scaler = DynamicGradScaler(
+                initial_scale=args.initial_loss_scale,
+                min_scale=args.min_loss_scale,
+                growth_factor=2.0,
+                backoff_factor=0.5,
+                growth_interval=args.loss_scale_window,
+                hysteresis=args.hysteresis)
+        # Megatron optimizer.
+        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
+                                           args.clip_grad)
+
+    # FP32.
+    return FP32Optimizer(optimizer, model, args.clip_grad)



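The helper is now called unconditionally and dispatches on args.fp16 itself, which is what makes fp32 "also work". A sketch of the call site, mirroring the training.py hunk below (the base optimizer here is hypothetical):

    base = torch.optim.Adam(model.parameters(), lr=1.5e-4)
    optimizer = get_megatron_optimizer(base, model)
    # fp16: FP16OptimizerWithFP16Params with a constant or dynamic scaler.
    # fp32: FP32Optimizer, whose loss scale is a fixed 1.0.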
@@ -239,9 +247,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                         # Store grads
                         master_param.requires_grad = True
                         # Copy tensor model parallel attributes.
-                        master_param.tensor_model_parallel = param.tensor_model_parallel
-                        #mpu.copy_tensor_model_parallel_attributes(master_param,
-                        #                                          param)
+                        mpu.copy_tensor_model_parallel_attributes(master_param,
+                                                                  param)
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = master_param
                         fp32_from_fp16_params_this_group.append(master_param)
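The switch to the helper matters because the fp32 master copy must mirror all three model-parallel attributes, not just tensor_model_parallel; partition_dim and partition_stride were dropped before. A toy illustration with a hypothetical fp16 param:

    param = torch.empty(4, 4, dtype=torch.half)
    param.tensor_model_parallel = True
    param.partition_dim = 0
    param.partition_stride = 1

    master_param = param.detach().clone().float()
    mpu.copy_tensor_model_parallel_attributes(master_param, param)
    assert master_param.partition_dim == 0  # previously lost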
@@ -286,10 +293,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     @torch.no_grad()
     def step(self):
 
+        timers = get_timers()
+
         # ==================================================
         # Copy gradients from model params to master params.
         # ==================================================
 
+        timers('optimizer-copy-to-master-grad').start()
         # This only needs to be done for the fp16 group.
         model_grads = []
         master_grads = []
@@ -307,11 +317,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                              self._dummy_overflow_buf,
                              [model_grads, master_grads],
                              1.0)
+        timers('optimizer-copy-to-master-grad').stop()
 
         # ==============================
         # Unscale and check for inf/nan.
         # ==============================
 
+        timers('optimizer-unscale-and-check-inf').start()
         # Append fp32 parameters.
         for master_group in self.fp32_from_fp32_groups:
             for master_param in master_group:
@@ -326,6 +338,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=mpu.get_model_parallel_group())
+        timers('optimizer-unscale-and-check-inf').stop()
 
         # ==================================
         # We are done with scaling gradients
@@ -344,11 +357,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # Clip the master gradients.
         # ==========================
 
+        timers('optimizer-clip-master-grad').start()
         fp32_params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 fp32_params.append(param)
         mpu.clip_grad_norm(fp32_params, self.clip_grad)
+        timers('optimizer-clip-master-grad').stop()
 
         # ===================
         # Step the optimizer.
@@ -360,6 +375,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # Update params from master params.
         # =================================
 
+        timers('optimizer-copy-master-to-model-params').start()
         # Only needed for the fp16 params.
         model_data = []
         master_data = []
@@ -374,5 +390,57 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                              self._dummy_overflow_buf,
                              [master_data, model_data],
                              1.0)
+        timers('optimizer-copy-master-to-model-params').stop()
 
         return True
+
+
+class FP32Optimizer(MegatronOptimizer):
+
+    def __init__(self, optimizer, model, clip_grad):
+
+        super(FP32Optimizer, self).__init__(optimizer)
+        self.model = model
+        self.clip_grad = clip_grad
+        self._scale = torch.cuda.FloatTensor([1.0])
+
+
+    def zero_grad(self, set_to_none=True):
+        """Copied from torch.optim.optimizer"""
+        for group in self.optimizer.param_groups:
+            _zero_grad_group_helper(group['params'], set_to_none)
+
+
+    def get_loss_scale(self):
+        """FP32 optimizer does not do any scaling."""
+        return self._scale
+
+
+    @torch.no_grad()
+    def step(self):
+        """Clip gradients (if needed) and step the base optimizer.
+        Always return successful since there is no overflow."""
+
+        # Clip gradients.
+        if self.clip_grad > 0.0:
+            parameters = []
+            parameter_names = []
+            for parameter_name, parameter in self.model.named_parameters():
+                parameters.append(parameter)
+                parameter_names.append(parameter_name)
+            mpu.clip_grad_norm(parameters, self.clip_grad,
+                               parameter_names=parameter_names)
+
+        # Update parameters.
+        self.optimizer.step()
+
+        # No overflow for FP32 optimizer.
+        return True
+
+
+    def state_dict(self):
+        return self.optimizer.state_dict()
+
+
+    def load_state_dict(self, state_dict):
+        self.optimizer.load_state_dict(state_dict)
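A minimal sketch of the new fp32 path end to end (toy model; CUDA assumed, since the dummy scale lives on the GPU):

    import torch

    model = torch.nn.Linear(10, 10).cuda()
    base = torch.optim.SGD(model.parameters(), lr=0.01)
    opt = FP32Optimizer(base, model, clip_grad=1.0)

    opt.zero_grad()
    model(torch.randn(2, 10).cuda()).sum().backward()
    assert opt.get_loss_scale().item() == 1.0  # no scaling in fp32
    assert opt.step() is True                  # never reports overflow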
+7 −6
@@ -233,8 +233,7 @@ def get_optimizer(model):
         betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
 
     # Wrap into fp16 optimizer.
-    if args.fp16:
-        optimizer = get_megatron_optimizer(optimizer)
+    optimizer = get_megatron_optimizer(optimizer, model)
     '''
         optimizer = FP16_Optimizer(optimizer,
                                    static_loss_scale=args.loss_scale,
@@ -737,10 +736,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-recv')
     add_to_logging('backward-send')
     add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-master-grad')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('backward-clip-grad')
+    add_to_logging('optimizer-copy-to-master-grad')
+    add_to_logging('optimizer-unscale-and-check-inf')
+    add_to_logging('optimizer-clip-master-grad')
+    add_to_logging('optimizer-copy-master-to-model-params')
     add_to_logging('optimizer')
     add_to_logging('batch-generator')
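The four new keys correspond to the start()/stop() pairs added in the optimizer step above; the pattern being timed is:

    timers = get_timers()
    timers('optimizer-clip-master-grad').start()
    mpu.clip_grad_norm(fp32_params, clip_grad)
    timers('optimizer-clip-master-grad').stop()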