megatron/arguments.py  +5 −0

@@ -154,6 +154,11 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.DDP_impl == 'local'
         args.use_contiguous_buffers_in_ddp = True
 
+    # If we use a contiguous buffer to hold main grads, we need to have
+    # local DDP.
+    if args.use_contiguous_buffers_in_ddp:
+        assert args.DDP_impl == 'local'
+
     if args.dataloader_type is None:
         args.dataloader_type = 'single'
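The check added above only validates configuration: the contiguous main-grad buffer is allocated by Megatron's local DDP wrapper, so it cannot be combined with the torch DDP implementation. Below is a minimal, standalone sketch of that constraint; the argparse.Namespace stands in for the parsed Megatron args, only the two flags used here are assumed, and nothing else of parse_args() is reproduced.

    from argparse import Namespace

    def check_contiguous_buffer_args(args):
        # The contiguous buffer holding main grads is only allocated by the
        # local DDP wrapper, so torch's DDP implementation is ruled out.
        if args.use_contiguous_buffers_in_ddp:
            assert args.DDP_impl == 'local', \
                'contiguous buffers require local DDP'

    # Passes: contiguous buffers together with local DDP.
    check_contiguous_buffer_args(
        Namespace(use_contiguous_buffers_in_ddp=True, DDP_impl='local'))
    # Namespace(use_contiguous_buffers_in_ddp=True, DDP_impl='torch') would raise.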
megatron/optimizer/optimizer.py  +5 −1

@@ -80,6 +80,9 @@ class MegatronOptimizer(ABC):
         self.params_have_main_grad = params_have_main_grad
         self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
 
+        if self.use_contiguous_buffers_in_ddp:
+            assert self.params_have_main_grad, \
+                "use of contiguous buffer requires that params have main grad"
 
     def get_parameters(self):
         params = []

@@ -319,7 +322,8 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
                 model_param.grad = None
-                if not self.use_contiguous_buffers_in_ddp:
+                if self.params_have_main_grad and \
+                   not self.use_contiguous_buffers_in_ddp:
                     model_param.main_grad = None
 
         # For fp32 grads, we need to reset the grads to main grad.
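The second hunk keeps model_param.main_grad alive when contiguous buffers are used, and the new assert in __init__ requires main grads to exist at all in that mode. A rough sketch of why, assuming only the general shape of Megatron's flat-buffer scheme rather than its actual MemoryBuffer/LocalDDP code: each param's .main_grad is just a view into one flat fp32 buffer, so setting it to None would orphan the view without freeing anything, and the optimizer would lose its gradient accumulation target.

    import torch

    # Two half-precision params standing in for a model's float16 weights.
    params = [torch.nn.Parameter(torch.randn(4, 4, dtype=torch.half)),
              torch.nn.Parameter(torch.randn(8, dtype=torch.half))]

    # One flat fp32 buffer for all main grads, as local DDP would allocate.
    numel = sum(p.numel() for p in params)
    flat_buffer = torch.zeros(numel, dtype=torch.float)

    # Carve a persistent .main_grad view out of the buffer for every param.
    offset = 0
    for p in params:
        p.main_grad = flat_buffer[offset:offset + p.numel()].view_as(p)
        offset += p.numel()

    # Zeroing or all-reducing the flat buffer in place updates every view,
    # which is why the views must never be dropped (main_grad = None) while
    # the contiguous buffer is in use.
    flat_buffer.fill_(1.0)
    assert params[0].main_grad.sum().item() == 16.0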