Commit 2348c99a authored by Mohammad Shoeybi
Browse files

Merge branch 'validation_loss_after_checkpoint' into 'main'

Run evaluation pass before checkpoint is saved

See merge request ADLR/megatron-lm!196
parents c84f0752 13bde16f
Loading
Loading
Loading
Loading
+8 −9
Original line number | Diff line number | Diff line
@@ -857,15 +857,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True


        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
@@ -874,6 +865,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0