tasks/finetune_utils.py  +11 −1

@@ -16,7 +16,7 @@
 """Finetune utilities."""

 from functools import partial
-
+import sys
 import torch

 from megatron import get_args
@@ -215,9 +215,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                                    optimizer, lr_scheduler)

            # Checkpointing
+            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,6 +228,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

+            # Exiting based on iterations
+            if args.exit_interval and iteration % args.exit_interval == 0:
+                if not saved_checkpoint:
+                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                torch.distributed.barrier()
+                print_rank_0('exiting program at iteration {}'.format(iteration))
+                sys.exit()
+
        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
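For reference, below is a minimal standalone sketch of the checkpoint-and-exit pattern this diff introduces: each iteration remembers whether it already produced a checkpoint (`saved_checkpoint`), and when the exit interval is reached it saves one if needed before exiting, so a restarted job (for example one stopped by a scheduler time limit) can resume from the latest iteration without writing duplicate checkpoints. The stub `save_checkpoint`, the `Args` container, and the loop bounds are hypothetical placeholders for illustration; the real code passes `model`, `optimizer`, and `lr_scheduler` to Megatron's `save_checkpoint` and also calls `torch.distributed.barrier()` so all ranks synchronize before `sys.exit()`.

```python
import sys


def save_checkpoint(iteration):
    # Stub standing in for megatron's save_checkpoint(iteration, model, optimizer, lr_scheduler).
    print('saved checkpoint at iteration {}'.format(iteration))


class Args:
    # Hypothetical stand-in for megatron's parsed arguments.
    save = True
    save_interval = 500
    exit_interval = 1000


args = Args()

for iteration in range(1, 2001):
    # Checkpointing: remember whether this iteration already produced a checkpoint.
    saved_checkpoint = False
    if args.save and args.save_interval and iteration % args.save_interval == 0:
        save_checkpoint(iteration)
        saved_checkpoint = True

    # Exiting based on iterations: save first (unless this iteration already did),
    # then exit cleanly. The real code additionally calls torch.distributed.barrier()
    # here so all ranks reach this point before the process exits.
    if args.exit_interval and iteration % args.exit_interval == 0:
        if not saved_checkpoint:
            save_checkpoint(iteration)
        print('exiting program at iteration {}'.format(iteration))
        sys.exit()
```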