Commit d078e54a authored by Mostofa Patwary
Browse files

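Added an exit interval for finetuning: when the iteration count is a multiple of args.exit_interval, save a checkpoint (if one was not already saved this iteration) and exit.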

parent 825375cf
Loading
Loading
Loading
Loading
+11 −1
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@
"""Finetune utilities."""

from functools import partial

import sys
import torch

from megatron import get_args
@@ -215,9 +215,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                                  optimizer, lr_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,6 +228,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)