Loading megatron/learning_rates.py +60 −47 Original line number Diff line number Diff line Loading @@ -12,59 +12,68 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch DataLoader for TFRecords""" import torch from torch.optim.lr_scheduler import _LRScheduler """Learning rate decay functions.""" import math from megatron import print_rank_0 class AnnealingLR(_LRScheduler): """Anneals the learning rate""" DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] class AnnealingLR(object): """Anneals the learning rate.""" def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, min_lr=0.0, def __init__(self, optimizer, start_lr, warmup_iter, total_iters, decay_style, last_iter, min_lr=0.0, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.min_lr = min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter + 1 self.end_iter = num_iters self.decay_style = decay_style.lower() if isinstance(decay_style, str) \ else None self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) if torch.distributed.get_rank() == 0: print('learning rate decaying', decay_style) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): # https://openreview.net/pdf?id=BJYwwY9ll pg. 4 """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter else: if self.decay_style == self.DECAY_STYLES[0]: lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter) elif self.decay_style == self.DECAY_STYLES[1]: lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1) elif self.decay_style == self.DECAY_STYLES[2]: num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter) lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) else: lr = self.start_lr return max(lr, self.min_lr) def step(self, step_num=None): """Set lr for all parameters groups.""" if step_num is None: step_num = self.num_iters + 1 self.num_iters = step_num Loading @@ -72,8 +81,9 @@ class AnnealingLR(_LRScheduler): for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): sd = { state_dict = { 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, Loading @@ -81,14 +91,16 @@ class AnnealingLR(_LRScheduler): 'end_iter': self.end_iter, 'min_lr': self.min_lr } return sd return state_dict def check_and_set_(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" if self.override_lr_scheduler: print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) return cls_value else: if not self.use_checkpoint_lr_scheduler: assert cls_value == sd_value, 'AnnealingLR: class input value' \ 'and checkpoint values for {} do not match'.format(name) Loading @@ -96,18 +108,19 @@ class AnnealingLR(_LRScheduler): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'], self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], 'learning rate') self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'], self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate') self.warmup_iter = self.check_and_set_(self.warmup_iter, self.warmup_iter = self._check_and_set(self.warmup_iter, sd['warmup_iter'], 'warmup iterations') self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'], self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], 'total number of iterations') self.decay_style = self.check_and_set_(self.decay_style, self.decay_style = self._check_and_set(self.decay_style, sd['decay_style'], 'decay style') Loading megatron/module.py +0 −1 Original line number Diff line number Diff line Loading @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Megatron Module""" import torch Loading megatron/training.py +2 −2 Original line number Diff line number Diff line Loading @@ -197,13 +197,13 @@ def get_learning_rate_scheduler(optimizer): else: num_iters = args.train_iters num_iters = max(1, num_iters) init_step = -1 init_step = 0 warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=args.lr, warmup_iter=warmup_iter, num_iters=num_iters, total_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, min_lr=args.min_lr, Loading megatron/utils.py +0 −15 Original line number Diff line number Diff line Loading @@ -89,8 +89,6 @@ def check_adlr_autoresume_termination(iteration, model, ################################################### from megatron import mpu def get_ltor_masks_and_position_ids(data, eod_token, Loading Loading @@ -148,16 +146,3 @@ def get_ltor_masks_and_position_ids(data, return attention_mask, loss_mask, position_ids def vocab_size_with_padding(num_tokens, args): after = num_tokens multiple = args.make_vocab_size_divisible_by * \ mpu.get_model_parallel_world_size() while (after % multiple) != 0: after += 1 print_rank_0('> padded vocab (size: {}) with {} dummy ' 'tokens (new size: {})'.format( num_tokens, after - num_tokens, after)) return after Loading
megatron/learning_rates.py +60 −47 Original line number Diff line number Diff line Loading @@ -12,59 +12,68 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch DataLoader for TFRecords""" import torch from torch.optim.lr_scheduler import _LRScheduler """Learning rate decay functions.""" import math from megatron import print_rank_0 class AnnealingLR(_LRScheduler): """Anneals the learning rate""" DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] class AnnealingLR(object): """Anneals the learning rate.""" def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, min_lr=0.0, def __init__(self, optimizer, start_lr, warmup_iter, total_iters, decay_style, last_iter, min_lr=0.0, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): # Class values. self.optimizer = optimizer self.start_lr = start_lr self.min_lr = min_lr self.warmup_iter = warmup_iter self.num_iters = last_iter + 1 self.end_iter = num_iters self.decay_style = decay_style.lower() if isinstance(decay_style, str) \ else None self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate self.step(self.num_iters) if torch.distributed.get_rank() == 0: print('learning rate decaying', decay_style) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) def get_lr(self): # https://openreview.net/pdf?id=BJYwwY9ll pg. 4 """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: return float(self.start_lr) * num_iters_ / self.warmup_iter else: if self.decay_style == self.DECAY_STYLES[0]: lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter) elif self.decay_style == self.DECAY_STYLES[1]: lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1) elif self.decay_style == self.DECAY_STYLES[2]: num_iters_ = num_iters_ - self.warmup_iter if self.decay_style == 'linear': lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter elif self.decay_style == 'cosine': lr = self.start_lr / 2.0 * (math.cos( math.pi * num_iters_ / self.end_iter) + 1) elif self.decay_style == 'exponential': # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter) lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) else: lr = self.start_lr return max(lr, self.min_lr) def step(self, step_num=None): """Set lr for all parameters groups.""" if step_num is None: step_num = self.num_iters + 1 self.num_iters = step_num Loading @@ -72,8 +81,9 @@ class AnnealingLR(_LRScheduler): for group in self.optimizer.param_groups: group['lr'] = new_lr def state_dict(self): sd = { state_dict = { 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, Loading @@ -81,14 +91,16 @@ class AnnealingLR(_LRScheduler): 'end_iter': self.end_iter, 'min_lr': self.min_lr } return sd return state_dict def check_and_set_(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" if self.override_lr_scheduler: print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) return cls_value else: if not self.use_checkpoint_lr_scheduler: assert cls_value == sd_value, 'AnnealingLR: class input value' \ 'and checkpoint values for {} do not match'.format(name) Loading @@ -96,18 +108,19 @@ class AnnealingLR(_LRScheduler): name)) return sd_value def load_state_dict(self, sd): self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'], self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], 'learning rate') self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'], self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate') self.warmup_iter = self.check_and_set_(self.warmup_iter, self.warmup_iter = self._check_and_set(self.warmup_iter, sd['warmup_iter'], 'warmup iterations') self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'], self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], 'total number of iterations') self.decay_style = self.check_and_set_(self.decay_style, self.decay_style = self._check_and_set(self.decay_style, sd['decay_style'], 'decay style') Loading
megatron/module.py +0 −1 Original line number Diff line number Diff line Loading @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Megatron Module""" import torch Loading
megatron/training.py +2 −2 Original line number Diff line number Diff line Loading @@ -197,13 +197,13 @@ def get_learning_rate_scheduler(optimizer): else: num_iters = args.train_iters num_iters = max(1, num_iters) init_step = -1 init_step = 0 warmup_iter = args.warmup * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=args.lr, warmup_iter=warmup_iter, num_iters=num_iters, total_iters=num_iters, decay_style=args.lr_decay_style, last_iter=init_step, min_lr=args.min_lr, Loading
megatron/utils.py +0 −15 Original line number Diff line number Diff line Loading @@ -89,8 +89,6 @@ def check_adlr_autoresume_termination(iteration, model, ################################################### from megatron import mpu def get_ltor_masks_and_position_ids(data, eod_token, Loading Loading @@ -148,16 +146,3 @@ def get_ltor_masks_and_position_ids(data, return attention_mask, loss_mask, position_ids def vocab_size_with_padding(num_tokens, args): after = num_tokens multiple = args.make_vocab_size_divisible_by * \ mpu.get_model_parallel_world_size() while (after % multiple) != 0: after += 1 print_rank_0('> padded vocab (size: {}) with {} dummy ' 'tokens (new size: {})'.format( num_tokens, after - num_tokens, after)) return after