megatron/fp16/fp16.py  (+9 −4)

@@ -22,6 +22,8 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron.module import MegatronModule

@@ -320,9 +322,12 @@ class FP16_Optimizer(object):
     def _downscale_master(self):
         if self.loss_scale != 1.0:
             for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1. / self.loss_scale)
+                grads = [p.grad for p in group['params'] if p.grad is not None]
+                _overflow_buf = torch.cuda.IntTensor([0])
+                multi_tensor_applier(amp_C.multi_tensor_scale,
+                                     _overflow_buf,
+                                     [grads, grads],
+                                     1. / self.loss_scale)

     def clip_master_grads(self, max_norm, norm_type=2):
         """
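The fp16.py change replaces the per-parameter `mul_` loop with a single fused `amp_C.multi_tensor_scale` launch over all gradients in a param group. Below is a minimal sketch of the same pattern, assuming apex is installed with its CUDA extensions (`amp_C`); the helper name `downscale_grads_` is illustrative and not part of Megatron-LM:

```python
# Minimal sketch of the fused gradient downscale used in _downscale_master.
# Assumes apex built with CUDA extensions; downscale_grads_ is illustrative.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier


def downscale_grads_(params, loss_scale):
    """Multiply every .grad by 1/loss_scale with fused, chunked kernel
    launches instead of one mul_ per parameter."""
    grads = [p.grad for p in params if p.grad is not None]
    if not grads or loss_scale == 1.0:
        return
    overflow_buf = torch.cuda.IntTensor([0])
    # Passing the same list as source and destination makes the scale in-place.
    multi_tensor_applier(amp_C.multi_tensor_scale,
                         overflow_buf,
                         [grads, grads],
                         1.0 / loss_scale)
```

The payoff is fewer kernel launches: one chunked launch covers the whole list of gradient tensors, and the kernel flips `overflow_buf` to a nonzero value if it sees a non-finite element, which callers can inspect if they want overflow detection for free.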
megatron/fp16/fp16util.py  (+10 −1)

@@ -18,6 +18,9 @@
 import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron import mpu

@@ -166,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
             if model.grad is not None:
                 if master.grad is None:
                     master.grad = Variable(master.data.new(*master.data.size()))
-                master.grad.data.copy_(model.grad.data)
             else:
                 master.grad = None
+        model_grads = [p.grad for p in model_params if p.grad is not None]
+        master_grads = [p.grad for p in master_params if p.grad is not None]
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [model_grads, master_grads],
+                             1.0)


 def master_params_to_model_params(model_params, master_params, flat_master=False):
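In fp16util.py the per-tensor `master.grad.data.copy_(model.grad.data)` copies are likewise collapsed into one fused call; because `multi_tensor_scale` casts into the destination dtype, a scale of 1.0 doubles as a batched fp16-to-fp32 copy. A hedged sketch of that idea (the helper name is illustrative, and it assumes each master param already has a grad buffer allocated, as the retained `if master.grad is None` branch above guarantees):

```python
# Illustrative fp16 -> fp32 gradient copy via multi_tensor_scale with scale 1.0.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier


def copy_model_grads_to_master_grads(model_params, master_params):
    model_grads = [p.grad for p in model_params if p.grad is not None]
    master_grads = [p.grad for p in master_params if p.grad is not None]
    if not model_grads:
        return
    overflow_buf = torch.cuda.IntTensor([0])
    # dst = src * 1.0, with the cast to the destination dtype done in the kernel.
    multi_tensor_applier(amp_C.multi_tensor_scale,
                         overflow_buf,
                         [model_grads, master_grads],
                         1.0)
```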
megatron/fp16/loss_scaler.py  (+16 −2)

@@ -14,6 +14,10 @@
 # limitations under the License.

 import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
+
+from megatron import mpu

 # item() is a recent addition, so this helps with backward compatibility.

@@ -57,7 +61,12 @@ class LossScaler:
         return self.cur_scale

     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale

@@ -180,7 +189,12 @@ class DynamicLossScaler:
         return self.cur_scale

     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in

     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
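`scale_gradient` is a module backward hook, so the new version scales `grad_in` in place with one fused launch and returns the same tuple instead of building a scaled copy. A rough sketch of how such a hook is wired up, assuming the incoming gradients are CUDA tensors; the class and layer below are illustrative, not the Megatron-LM scaler:

```python
# Sketch of an in-place scaling backward hook (illustrative names only).
import torch
import torch.nn as nn
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier


class TinyLossScaler:
    def __init__(self, loss_scale=2.0 ** 16):
        self.loss_scale = loss_scale

    def scale_gradient(self, module, grad_in, grad_out):
        # Filter Nones (e.g. inputs that do not require grad), then scale the
        # remaining tensors in place; returning grad_in keeps the hook contract.
        grads = [g for g in grad_in if g is not None]
        if grads:
            overflow_buf = torch.cuda.IntTensor([0])
            multi_tensor_applier(amp_C.multi_tensor_scale,
                                 overflow_buf,
                                 [grads, grads],
                                 self.loss_scale)
        return grad_in


if torch.cuda.is_available():
    scaler = TinyLossScaler()
    layer = nn.Linear(8, 8).cuda()
    layer.register_backward_hook(scaler.scale_gradient)
    layer(torch.randn(2, 8, device="cuda")).sum().backward()
```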
megatron/model/gpt2_model.py  (+1 −2)

@@ -27,8 +27,7 @@
 from .utils import scaled_init_method_normal


 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-                       10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores
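The gpt2_model.py change swaps the arithmetic masking (`scores * mask - 10000 * (1 - mask)`) for an in-place `masked_fill_`. Since `masked_fill_` writes where the mask is True, `ltor_mask` is now expected to be a boolean mask that is True at positions to block, the inverse of the old 0/1 keep-mask. A small self-contained check of the equivalence (all names are local to the snippet):

```python
# Contrast the old arithmetic mask with masked_fill_ on a toy causal mask.
import torch

scores = torch.randn(1, 1, 4, 4)

# Old formulation: keep_mask is 1.0 where attention is allowed.
keep_mask = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)
old = scores * keep_mask - 10000.0 * (1.0 - keep_mask)

# New formulation: block_mask is True where attention must be masked out.
block_mask = keep_mask < 0.5
new = scores.clone()
new.masked_fill_(block_mask, -10000.0)

# Blocked positions become -10000.0 in both; kept positions stay untouched.
assert torch.equal(old, new)
```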
megatron/text_generation_utils.py  (+1 −2)

@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, attention_mask, position_ids