megatron/arguments.py  +25 −4

@@ -129,11 +129,26 @@ def parse_args(extra_args_provider=None, defaults={},
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
+        assert not args.bf16
         args.params_dtype = torch.half
+    if args.bf16:
+        assert not args.fp16
+        args.params_dtype = torch.bfloat16
+        # No fusion is supported for bfloat for now
+        assert not args.masked_softmax_fusion
+        assert not args.bias_gelu_fusion
+        assert not args.bias_dropout_fusion
     if args.rank == 0:
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

+    # If we do accumulation and all-reduces in fp32, we need to have
+    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    if args.accumulate_allreduce_grads_in_fp32:
+        assert args.DDP_impl == 'local'
+        args.use_contiguous_buffers_in_ddp = True
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -204,8 +219,8 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
     if args.fp32_residual_connection:
-        assert args.fp16, \
-            'residual connection in fp32 only supported when using fp16.'
+        assert args.fp16 or args.bf16, \
+            'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \

@@ -528,6 +543,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bfloat16 mode.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'

@@ -549,8 +566,9 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32. '
                        'This flag is ignored unless '
                        '--no-query-key-layer-scaling is specified.')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='All-reduce in fp32')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')

@@ -577,6 +595,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use contiguous buffer in DDP. Note that '
+                       'this option only works with local DDP.')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                        action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
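To make the new flag interactions concrete, here is a minimal, standalone argparse sketch (not Megatron's actual parser; the attribute names simply mirror the diff above): fp16 and bf16 are mutually exclusive, the parameter dtype follows the chosen flag, and --accumulate-allreduce-grads-in-fp32 requires local DDP and implies the contiguous-buffer option.

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--bf16', action='store_true')
parser.add_argument('--accumulate-allreduce-grads-in-fp32', action='store_true')
parser.add_argument('--DDP-impl', default='local', choices=['local', 'torch'])
args = parser.parse_args(['--bf16', '--accumulate-allreduce-grads-in-fp32'])

# Parameters dtype: fp16 and bf16 may not be combined.
args.params_dtype = torch.float
if args.fp16:
    assert not args.bf16
    args.params_dtype = torch.half
if args.bf16:
    assert not args.fp16
    args.params_dtype = torch.bfloat16

# fp32 grad accumulation/all-reduce only works with local DDP and needs
# the contiguous grad buffers.
if args.accumulate_allreduce_grads_in_fp32:
    assert args.DDP_impl == 'local'
    args.use_contiguous_buffers_in_ddp = True

print(args.params_dtype)  # torch.bfloat16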
megatron/model/__init__.py  +5 −3

@@ -16,11 +16,13 @@
 _LAYER_NORM = None


-def import_layernorm(fp32_residual_connection):
+def import_layernorm(fp32_residual_connection, bf16):
     global _LAYER_NORM
     if not _LAYER_NORM:
-        if fp32_residual_connection:
+        if bf16:
+            from torch.nn import LayerNorm
+        elif fp32_residual_connection:
             from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
         else:
             from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

@@ -39,6 +41,6 @@
 from .gpt_model import (GPTModel, GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
-from .module import FP16Module
+from .module import Float16Module

megatron/model/bert_model.py  +1 −1

@@ -78,7 +78,7 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
+        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
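As a rough, self-contained illustration of the selection order in import_layernorm (the global caching is dropped here; the apex and fused-kernel imports only happen on the branches that need them, so this runs without apex when bf16 is set; the absolute import path is illustrative):

def pick_layernorm(fp32_residual_connection, bf16):
    # bf16 has no fused layer-norm support in this change, so fall back to
    # the native PyTorch implementation.
    if bf16:
        from torch.nn import LayerNorm
    elif fp32_residual_connection:
        from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
    else:
        from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
    return LayerNorm

# With bf16 enabled this resolves to torch.nn.LayerNorm, matching the call
# site updated in bert_model.py above.
LayerNorm = pick_layernorm(fp32_residual_connection=False, bf16=True)
layernorm = LayerNorm(1024, eps=1e-5)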
megatron/model/distributed.py  +178 −72

@@ -13,100 +13,206 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from abc import ABC
+from abc import abstractmethod

 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import torch.distributed as dist
-from torch.nn.modules import Module
-from torch.autograd import Variable

-from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule


-class DistributedDataParallel(MegatronModule):
-
-    def __init__(self, module):
-        super(DistributedDataParallel, self).__init__()
-        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
-
-        self.module = module
-        self.data_parallel_group = mpu.get_data_parallel_group()
-
-        def allreduce_params(reduce_after=True, no_scale=False,
-                             fp32_allreduce=False):
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                buckets = {}
-                for name, param in self.module.named_parameters():
-                    if param.requires_grad and param.grad is not None:
-                        tp = (param.data.type())
-                        if tp not in buckets:
-                            buckets[tp] = []
-                        buckets[tp].append(param)
-                if self.warn_on_half:
-                    if torch.cuda.HalfTensor in buckets:
-                        print("WARNING: gloo dist backend for half parameters may be extremely slow."
-                              + " It is recommended to use the NCCL backend in this case.")
-                        self.warn_on_half = False
-                for tp in buckets:
-                    bucket = buckets[tp]
-                    grads = [param.grad.data for param in bucket]
-                    coalesced = _flatten_dense_tensors(grads)
-                    if fp32_allreduce:
-                        coalesced = coalesced.float()
-                    if not no_scale and not reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    dist.all_reduce(coalesced, group=self.data_parallel_group)
-                    torch.cuda.synchronize()
-                    if not no_scale and reduce_after:
-                        coalesced /= dist.get_world_size(group=self.data_parallel_group)
-                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                        buf.copy_(synced)
-
-        self.hook_handles = []
-        self.hooks = []
-        for param in list(self.module.parameters()):
-            def allreduce_hook(*unused):
-                Variable._execution_engine.queue_callback(allreduce_params)
-            # handle = param.register_hook(allreduce_hook)
-            # self.hooks.append(allreduce_hook)
-            # self.hook_handles.append(handle)
-        self.allreduce_params = allreduce_params
-
-    def forward(self, *inputs, **kwargs):
-        self.needs_reduction = True
-        return self.module(*inputs, **kwargs)
-
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
-        #[h.remove() for h in self.hook_handles]
-        sd = self.module.state_dict(destination, prefix, keep_vars)
-        # for handle, hook in zip(self.hook_handles, self.hooks):
-        #     d = handle.hooks_dict_ref()
-        #     d[handle.id] = hook
-        return sd
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        return self.module.state_dict_for_save_checkpoint(destination, prefix,
-                                                          keep_vars)
-
-    def load_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
-
-    '''
-    def _sync_buffers(self):
-        buffers = list(self.module._all_buffers())
-        if len(buffers) > 0:
-            # cross-node buffer sync
-            flat_buffers = _flatten_dense_tensors(buffers)
-            dist.broadcast(flat_buffers, 0)
-            for buf, synced in zip(buffers,
-                                   _unflatten_dense_tensors(flat_buffers, buffers)):
-                buf.copy_(synced)
-
-    def train(self, mode=True):
-        # Clear NCCL communicator and CUDA event cache of the default group ID,
-        # These cache will be recreated at the later call. This is currently a
-        # work-around for a potential NCCL deadlock.
-        if dist._backend == dist.dist_backend.NCCL:
-            dist._clear_group_cache()
-        super(DistributedDataParallel, self).train(mode)
-        self.module.train(mode)
-    '''


+class MemoryBuffer:
+
+    def __init__(self, numel, dtype):
+        self.numel = numel
+        self.dtype = dtype
+        self.data = torch.zeros(self.numel,
+                                dtype=self.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False)
+
+    def zero(self):
+        """Reset the buffer to zero."""
+        self.data.zero_()
+
+    def get(self, shape, start_index):
+        """Return a tensor with the input `shape` as a view into the
+        1-D data starting at `start_index`."""
+        end_index = start_index + shape.numel()
+        assert end_index <= self.numel, \
+            'requested tensor is out of the buffer range.'
+        buffer_tensor = self.data[start_index:end_index]
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
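A small usage sketch for MemoryBuffer as defined above (illustrative shapes; a CUDA device is assumed, since the buffer is allocated on torch.cuda.current_device()): each parameter gets a view into one flat allocation, carved out back to front the way the DDP wrapper below assigns param.main_grad.

import torch

param_shapes = [torch.Size([4, 8]), torch.Size([16])]
total = sum(shape.numel() for shape in param_shapes)
buf = MemoryBuffer(total, torch.float)

# Carve per-parameter views out of the flat buffer, back to front.
start = total
views = []
for shape in param_shapes:
    start -= shape.numel()
    views.append(buf.get(shape, start))

# Writes through a view land in the flat buffer; zero() clears everything.
views[0].fill_(1.0)
assert buf.data.sum().item() == 32.0
buf.zero()
assert buf.data.sum().item() == 0.0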
""" def __init__(self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers): super(DistributedDataParallel, self).__init__(module) self.accumulate_allreduce_grads_in_fp32 \ = accumulate_allreduce_grads_in_fp32 self.use_contiguous_buffers = use_contiguous_buffers # If we are using fp32-accumulate-allreduce explicitly # this means we need main grads in a continous buffer. if self.accumulate_allreduce_grads_in_fp32: assert self.use_contiguous_buffers # =================================== # Rest of this part applies only to # the case we use continuous buffers. # =================================== self._grad_buffers = None if self.use_contiguous_buffers: self._grad_buffers = {} # Simple function to define buffer type. def _get_buffer_type(param): return torch.float if \ self.accumulate_allreduce_grads_in_fp32 else param.dtype # First calculate total number of elements per type. type_num_elements = {} for param in self.module.parameters(): if param.requires_grad: dtype = _get_buffer_type(param) type_num_elements[dtype] = type_num_elements.get(dtype, 0) \ + param.data.nelement() # Allocate the buffer. for dtype, num_elements in type_num_elements.items(): self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype) # Assume the back prop order is reverse the params order, # store the start index for the gradients. for param in self.module.parameters(): if param.requires_grad: dtype = _get_buffer_type(param) type_num_elements[dtype] -= param.data.nelement() param.main_grad = self._grad_buffers[dtype].get( param.data.shape, type_num_elements[dtype]) # Backward hook. # Accumalation function for the gradients. We need # to store them so they don't go out of scope. self.grad_accs = [] # Loop over all the parameters in the model. for param in self.module.parameters(): if param.requires_grad: # Expand so we get access to grad_fn. param_tmp = param.expand_as(param) # Get the gradient accumulator functtion. grad_acc = param_tmp.grad_fn.next_functions[0][0] grad_acc.register_hook(self._make_param_hook(param)) self.grad_accs.append(grad_acc) def _make_param_hook(self, param): """Create the all-reduce hook for backprop.""" # Hook used for back-prop. def param_hook(*unused): # Add the gradient to the buffer. if param.grad.data is not None: param.main_grad.add_(param.grad.data) # Now we can deallocate grad memory. param.grad = None return param_hook def zero_grad_buffer(self): """Set the grad buffer data to zero. Needs to be called at the begining of each iteration.""" assert self._grad_buffers is not None, 'buffers are not initialized.' for _, buffer_ in self._grad_buffers.items(): buffer_.zero() def allreduce_gradients(self): """Reduce gradients across data parallel ranks.""" # If we have buffers, simply reduce the data in the buffer. if self._grad_buffers is not None: for _, buffer_ in self._grad_buffers.items(): buffer_.data /= mpu.get_data_parallel_world_size() torch.distributed.all_reduce( buffer_.data, group=mpu.get_data_parallel_group()) else: # Otherwise, bucketize and all-reduce buckets = {} # Pack the buckets. for param in self.module.parameters(): if param.requires_grad and param.grad is not None: tp = param.data.type() if tp not in buckets: buckets[tp] = [] buckets[tp].append(param) param.main_grad = param.grad # For each bucket, all-reduce and copy all-reduced grads. 
megatron/model/module.py  +26 −12

@@ -25,6 +25,7 @@
 from megatron import mpu

 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

@@ -109,6 +110,7 @@ class MegatronModule(torch.nn.Module):
                 "this needs to be handled manually. If you are training "
                 "something is definitely wrong.")
+

 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val`
     #is a nested tuple/list structure."""

@@ -120,44 +122,56 @@ def conversion_helper(val, conversion):
     return rtn


-def fp32_to_fp16(val):
-    """Convert fp32 `val` to fp16"""
+def fp32_to_float16(val, float16_convertor):
+    """Convert fp32 `val` to fp16/bf16"""
     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
         if isinstance(val_typecheck, _FLOAT_TYPES):
-            val = val.half()
+            val = float16_convertor(val)
         return val
     return conversion_helper(val, half_conversion)


-def fp16_to_fp32(val):
-    """Convert fp16 `val` to fp32"""
+def float16_to_fp32(val):
+    """Convert fp16/bf16 `val` to fp32"""
     def float_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (Parameter, Variable)):
             val_typecheck = val.data
-        if isinstance(val_typecheck, _HALF_TYPES):
+        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
             val = val.float()
         return val
     return conversion_helper(val, float_conversion)


-class FP16Module(MegatronModule):
+class Float16Module(MegatronModule):

-    def __init__(self, module):
-        super(FP16Module, self).__init__()
+    def __init__(self, module, args):
+        super(Float16Module, self).__init__()

-        self.add_module('module', module.half())
+        if args.fp16:
+            self.add_module('module', module.half())
+            def float16_convertor(val):
+                return val.half()
+        elif args.bf16:
+            self.add_module('module', module.bfloat16())
+            def float16_convertor(val):
+                return val.bfloat16()
+        else:
+            raise Exception('should not be here')
+
+        self.float16_convertor = float16_convertor

     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
-            inputs = fp32_to_fp16(inputs)
+            inputs = fp32_to_float16(inputs, self.float16_convertor)
         outputs = self.module(*inputs, **kwargs)
         if mpu.is_pipeline_last_stage():
-            outputs = fp16_to_fp32(outputs)
+            outputs = float16_to_fp32(outputs)
         return outputs
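A rough sketch of what the convertor plumbing above does at the pipeline boundaries, reduced to plain tensors (Float16Module itself also consults mpu.is_pipeline_first_stage/is_pipeline_last_stage and walks nested structures via conversion_helper, both skipped here):

import torch

def make_convertor(bf16):
    # Mirrors Float16Module's float16_convertor: half() for fp16,
    # bfloat16() for bf16.
    if bf16:
        return lambda t: t.bfloat16()
    return lambda t: t.half()

convertor = make_convertor(bf16=True)

x = torch.randn(2, 4)              # fp32 activations entering the first stage
x16 = convertor(x)                 # cast down before running the wrapped module
assert x16.dtype == torch.bfloat16

y = x16.float()                    # cast back to fp32 on the last stage's outputs
assert y.dtype == torch.float32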