Loading megatron/learning_rates.py +39 −14 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ import math from megatron import print_rank_0 class AnnealingLR(object): """Anneals the learning rate.""" Loading @@ -31,44 +30,67 @@ class AnnealingLR(object): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.start_lr = float(start_lr) self.min_lr = min_lr assert self.min_lr >= 0.0 assert self.start_lr >= self.min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 assert self.warmup_iter < self.end_iter self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. # Use linear warmup for the initial part. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter return self.start_lr * float(self.num_iters) / \ float(self.warmup_iter) # If the learning rate is constant, just return the initial value. if self.decay_style == 'constant': return self.start_lr # For any iterations larger than `self.end_iter`, use `self.min_lr`. if self.num_iters > self.end_iter: return self.min_lr # If we are done with the warmup period, use the decay style. current_iter = self.num_iters - self.warmup_iter decay_iters = self.end_iter - self.warmup_iter decay_ratio = float(current_iter) / float(decay_iters) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 delta_lr = self.start_lr - self.min_lr num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter coeff = (1.0 - decay_ratio) elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: lr = self.start_lr return max(lr, self.min_lr) raise Exception('{} decay style is not supported.'.format( self.decay_style)) return self.min_lr + coeff * delta_lr def step(self, step_num=None): """Set lr for all parameters groups.""" Loading @@ -79,6 +101,7 @@ class AnnealingLR(object): for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): state_dict = { 'start_lr': self.start_lr, Loading @@ -90,6 +113,7 @@ class AnnealingLR(object): } return state_dict def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" Loading @@ -104,6 +128,7 @@ class AnnealingLR(object): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], Loading Loading
megatron/learning_rates.py +39 −14 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ import math from megatron import print_rank_0 class AnnealingLR(object): """Anneals the learning rate.""" Loading @@ -31,44 +30,67 @@ class AnnealingLR(object): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.start_lr = float(start_lr) self.min_lr = min_lr assert self.min_lr >= 0.0 assert self.start_lr >= self.min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 assert self.warmup_iter < self.end_iter self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. # Use linear warmup for the initial part. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter return self.start_lr * float(self.num_iters) / \ float(self.warmup_iter) # If the learning rate is constant, just return the initial value. if self.decay_style == 'constant': return self.start_lr # For any iterations larger than `self.end_iter`, use `self.min_lr`. if self.num_iters > self.end_iter: return self.min_lr # If we are done with the warmup period, use the decay style. current_iter = self.num_iters - self.warmup_iter decay_iters = self.end_iter - self.warmup_iter decay_ratio = float(current_iter) / float(decay_iters) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 delta_lr = self.start_lr - self.min_lr num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter coeff = (1.0 - decay_ratio) elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: lr = self.start_lr return max(lr, self.min_lr) raise Exception('{} decay style is not supported.'.format( self.decay_style)) return self.min_lr + coeff * delta_lr def step(self, step_num=None): """Set lr for all parameters groups.""" Loading @@ -79,6 +101,7 @@ class AnnealingLR(object): for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): state_dict = { 'start_lr': self.start_lr, Loading @@ -90,6 +113,7 @@ class AnnealingLR(object): } return state_dict def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" Loading @@ -104,6 +128,7 @@ class AnnealingLR(object): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], Loading