megatron/fp16/fp16.py  +9 −4

```diff
@@ -22,6 +22,8 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 
 from megatron.module import MegatronModule
@@ -320,9 +322,12 @@ class FP16_Optimizer(object):
     def _downscale_master(self):
         if self.loss_scale != 1.0:
             for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1. / self.loss_scale)
+                grads = [p.grad for p in group['params'] if p.grad is not None]
+                _overflow_buf = torch.cuda.IntTensor([0])
+                multi_tensor_applier(amp_C.multi_tensor_scale,
+                                     _overflow_buf,
+                                     [grads, grads],
+                                     1./self.loss_scale)
 
     def clip_master_grads(self, max_norm, norm_type=2):
         """
```
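For context, here is a minimal standalone sketch of the fused downscale pattern this hunk introduces, assuming apex is installed with its `amp_C` CUDA extension built. The `downscale_grads` helper name, the `HAVE_APEX` guard, and the CPU fallback branch are illustrative additions, not part of the change itself.

```python
import torch

try:
    from apex.multi_tensor_apply import multi_tensor_applier
    import amp_C
    HAVE_APEX = True
except ImportError:
    HAVE_APEX = False


def downscale_grads(param_groups, loss_scale):
    """Divide every available gradient by loss_scale."""
    if loss_scale == 1.0:
        return
    for group in param_groups:
        grads = [p.grad for p in group['params'] if p.grad is not None]
        if not grads:
            continue
        if HAVE_APEX and grads[0].is_cuda:
            # One fused kernel launch per chunk of the group instead of one
            # mul_ launch per tensor.
            overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(amp_C.multi_tensor_scale,
                                 overflow_buf,
                                 [grads, grads],   # src and dst alias: in-place
                                 1.0 / loss_scale)
        else:
            # Fallback: same result, one kernel launch per tensor.
            for g in grads:
                g.mul_(1.0 / loss_scale)
```

Passing the same list as both source and destination (`[grads, grads]`) makes `multi_tensor_scale` operate in place, which is how this replaces the per-parameter `mul_` loop.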
megatron/fp16/fp16util.py  +10 −0

```diff
@@ -18,6 +18,9 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
 from megatron import mpu
@@ -169,6 +172,13 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
                 master.grad.data.copy_(model.grad.data)
         else:
             master.grad = None
+    model_grads = [p.grad for p in model_params if p.grad is not None]
+    master_grads = [p.grad for p in master_params if p.grad is not None]
+    _overflow_buf = torch.cuda.IntTensor([0])
+    multi_tensor_applier(amp_C.multi_tensor_scale,
+                         _overflow_buf,
+                         [model_grads, master_grads],
+                         1.0)
 
 def master_params_to_model_params(model_params, master_params, flat_master=False):
```
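With a scale of 1.0, `multi_tensor_scale` acts as a batched copy-with-cast, which is what moves the fp16 model gradients into the fp32 master gradients here. A hedged sketch under that assumption (the helper name and the allocation loop are illustrative, and the two gradient lists are assumed to line up one-to-one):

```python
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C


def copy_model_grads_to_master_grads(model_params, master_params):
    """Fused fp16 -> fp32 gradient copy using amp_C.multi_tensor_scale."""
    # Master grads must exist before the fused copy can write into them.
    for model, master in zip(model_params, master_params):
        if model.grad is not None and master.grad is None:
            master.grad = torch.empty_like(master)

    model_grads = [p.grad for p in model_params if p.grad is not None]
    master_grads = [p.grad for p in master_params if p.grad is not None]
    if model_grads:
        overflow_buf = torch.cuda.IntTensor([0])
        # scale = 1.0: pure copy, with the cast to fp32 done by the kernel.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [model_grads, master_grads],  # src, dst
                             1.0)
```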
megatron/fp16/loss_scaler.py  +12 −2

```diff
@@ -57,7 +57,12 @@ class LossScaler:
         return self.cur_scale
 
     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
 
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
@@ -180,7 +185,12 @@ class DynamicLossScaler:
         return self.cur_scale
 
    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
 
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
```
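Both scalers now scale `grad_in` in place with a single fused call and return it, instead of building a new tuple in Python. A minimal sketch of the same hook-style pattern, assuming apex's fused kernels are available; the `SimpleLossScaler` class, the `None` filter, and the registration line at the end are illustrative, not taken from the change:

```python
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C


class SimpleLossScaler:
    def __init__(self, scale=2.0 ** 16):
        self.loss_scale = scale

    def scale_gradient(self, module, grad_in, grad_out):
        # Backward hooks can receive None entries; the fused kernel expects
        # real CUDA tensors, so filter before the in-place scale.
        grads = [g for g in grad_in if g is not None]
        if grads:
            overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(amp_C.multi_tensor_scale,
                                 overflow_buf,
                                 [grads, grads],   # in-place scaling
                                 self.loss_scale)
        return grad_in


# Usage (illustrative): attach the scaler to a module so its input gradients
# are multiplied by loss_scale during backprop, e.g.
#   scaler = SimpleLossScaler()
#   module.register_backward_hook(scaler.scale_gradient)
```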