Loading megatron/arguments.py +1 −1 Original line number Diff line number Diff line Loading @@ -280,7 +280,7 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learing rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', choices=['constant', 'linear', 'cosine', 'exponential'], choices=['constant', 'linear', 'cosine'], help='Learning rate decay function.') group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' Loading megatron/learning_rates.py +82 −39 Original line number Diff line number Diff line Loading @@ -19,77 +19,101 @@ import math from megatron import print_rank_0 class AnnealingLR(object): """Anneals the learning rate.""" def __init__(self, optimizer, start_lr, warmup_iter, total_iters, decay_style, last_iter, min_lr=0.0, def __init__(self, optimizer, max_lr, min_lr, warmup_steps, decay_steps, decay_style, num_steps, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.max_lr = float(max_lr) self.min_lr = min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 assert self.min_lr >= 0.0 assert self.max_lr >= self.min_lr self.warmup_steps = warmup_steps self.num_steps = num_steps self.decay_steps = decay_steps assert self.decay_steps > 0 assert self.warmup_steps < self.decay_steps self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) self.step(step_num=self.num_steps) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: return self.max_lr * float(self.num_steps) / \ float(self.warmup_steps) # If the learning rate is constant, just return the initial value. if self.decay_style == 'constant': return self.max_lr # For any steps larger than `self.decay_steps`, use `self.min_lr`. if self.num_steps > self.decay_steps: return self.min_lr # If we are done with the warmup period, use the decay style. num_steps_ = self.num_steps - self.warmup_steps decay_steps_ = self.decay_steps - self.warmup_steps decay_ratio = float(num_steps_) / float(decay_steps_) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter coeff = (1.0 - decay_ratio) elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: lr = self.start_lr return max(lr, self.min_lr) raise Exception('{} decay style is not supported.'.format( self.decay_style)) return self.min_lr + coeff * delta_lr def step(self, step_num=None): def step(self, increment=1, step_num=None): """Set lr for all parameters groups.""" if step_num is None: step_num = self.num_iters + 1 self.num_iters = step_num step_num = self.num_steps + increment self.num_steps = step_num new_lr = self.get_lr() for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): state_dict = { 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, 'max_lr': self.max_lr, 'warmup_steps': self.warmup_steps, 'num_steps': self.num_steps, 'decay_style': self.decay_style, 'end_iter': self.end_iter, 'decay_steps': self.decay_steps, 'min_lr': self.min_lr } return state_dict def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" Loading @@ -104,20 +128,39 @@ class AnnealingLR(object): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], if 'start_lr' in sd: max_lr_ = sd['start_lr'] else: max_lr_ = sd['max_lr'] self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate') self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate') self.warmup_iter = self._check_and_set(self.warmup_iter, sd['warmup_iter'], if 'warmup_iter' in sd: warmup_steps_ = sd['warmup_iter'] else: warmup_steps_ = sd['warmup_steps'] self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, 'warmup iterations') self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], if 'end_iter' in sd: decay_steps_ = sd['end_iter'] else: decay_steps_ = sd['decay_steps'] self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, 'total number of iterations') self.decay_style = self._check_and_set(self.decay_style, sd['decay_style'], 'decay style') self.num_iters = sd['num_iters'] self.step(self.num_iters) if 'num_iters' in sd: self.num_steps = sd['num_iters'] else: self.num_steps = sd['num_steps'] self.step(step_num=self.num_steps) megatron/training.py +5 −5 Original line number Diff line number Diff line Loading @@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer): warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=args.lr, warmup_iter=warmup_iter, total_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, max_lr=args.lr, min_lr=args.min_lr, warmup_steps=warmup_iter, decay_steps=num_iters, decay_style=args.lr_decay_style, num_steps=init_step, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, override_lr_scheduler=args.override_lr_scheduler) Loading Loading
megatron/arguments.py +1 −1 Original line number Diff line number Diff line Loading @@ -280,7 +280,7 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learing rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', choices=['constant', 'linear', 'cosine', 'exponential'], choices=['constant', 'linear', 'cosine'], help='Learning rate decay function.') group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' Loading
megatron/learning_rates.py +82 −39 Original line number Diff line number Diff line Loading @@ -19,77 +19,101 @@ import math from megatron import print_rank_0 class AnnealingLR(object): """Anneals the learning rate.""" def __init__(self, optimizer, start_lr, warmup_iter, total_iters, decay_style, last_iter, min_lr=0.0, def __init__(self, optimizer, max_lr, min_lr, warmup_steps, decay_steps, decay_style, num_steps, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.max_lr = float(max_lr) self.min_lr = min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 assert self.min_lr >= 0.0 assert self.max_lr >= self.min_lr self.warmup_steps = warmup_steps self.num_steps = num_steps self.decay_steps = decay_steps assert self.decay_steps > 0 assert self.warmup_steps < self.decay_steps self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) self.step(step_num=self.num_steps) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: return self.max_lr * float(self.num_steps) / \ float(self.warmup_steps) # If the learning rate is constant, just return the initial value. if self.decay_style == 'constant': return self.max_lr # For any steps larger than `self.decay_steps`, use `self.min_lr`. if self.num_steps > self.decay_steps: return self.min_lr # If we are done with the warmup period, use the decay style. num_steps_ = self.num_steps - self.warmup_steps decay_steps_ = self.decay_steps - self.warmup_steps decay_ratio = float(num_steps_) / float(decay_steps_) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter coeff = (1.0 - decay_ratio) elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: lr = self.start_lr return max(lr, self.min_lr) raise Exception('{} decay style is not supported.'.format( self.decay_style)) return self.min_lr + coeff * delta_lr def step(self, step_num=None): def step(self, increment=1, step_num=None): """Set lr for all parameters groups.""" if step_num is None: step_num = self.num_iters + 1 self.num_iters = step_num step_num = self.num_steps + increment self.num_steps = step_num new_lr = self.get_lr() for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): state_dict = { 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, 'max_lr': self.max_lr, 'warmup_steps': self.warmup_steps, 'num_steps': self.num_steps, 'decay_style': self.decay_style, 'end_iter': self.end_iter, 'decay_steps': self.decay_steps, 'min_lr': self.min_lr } return state_dict def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" Loading @@ -104,20 +128,39 @@ class AnnealingLR(object): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], if 'start_lr' in sd: max_lr_ = sd['start_lr'] else: max_lr_ = sd['max_lr'] self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate') self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate') self.warmup_iter = self._check_and_set(self.warmup_iter, sd['warmup_iter'], if 'warmup_iter' in sd: warmup_steps_ = sd['warmup_iter'] else: warmup_steps_ = sd['warmup_steps'] self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, 'warmup iterations') self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], if 'end_iter' in sd: decay_steps_ = sd['end_iter'] else: decay_steps_ = sd['decay_steps'] self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, 'total number of iterations') self.decay_style = self._check_and_set(self.decay_style, sd['decay_style'], 'decay style') self.num_iters = sd['num_iters'] self.step(self.num_iters) if 'num_iters' in sd: self.num_steps = sd['num_iters'] else: self.num_steps = sd['num_steps'] self.step(step_num=self.num_steps)
megatron/training.py +5 −5 Original line number Diff line number Diff line Loading @@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer): warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=args.lr, warmup_iter=warmup_iter, total_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, max_lr=args.lr, min_lr=args.min_lr, warmup_steps=warmup_iter, decay_steps=num_iters, decay_style=args.lr_decay_style, num_steps=init_step, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, override_lr_scheduler=args.override_lr_scheduler) Loading