Commit 5e9b90b5 authored by Isaac

Removing evaluation step from pretraining loop

parent 10cacfb1
+1 −1
@@ -498,7 +498,7 @@ def _add_training_args(parser):
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
-                       choices=['adam', 'sgd'],
+                       choices=['adam', 'sgd', 'lamb'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
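
For context, the only change in the first file is a wider argparse choices list, so '--optimizer lamb' is now accepted by the parser (whether a LAMB optimizer is actually constructed elsewhere is not shown in this commit). A minimal standalone sketch of the effect, using plain argparse rather than the real _add_training_args:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='training')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd', 'lamb'],
                       help='Optimizer function')

    # '--optimizer lamb' now parses; with the old choices list this call
    # would exit with "invalid choice: 'lamb'".
    args = parser.parse_args(['--optimizer', 'lamb'])
    assert args.optimizer == 'lamb'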
+15 −9
@@ -392,6 +392,7 @@ def train_step(forward_step_func, data_iterator,
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)
+    #print_rank_0('End forward_func')

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
@@ -671,13 +672,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                       model,
                       optimizer,
                        lr_scheduler)
+        #print_rank_0('End train_step')
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()
+        #print_rank_0('HERE2')

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
+        #print_rank_0('HERE3')
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
@@ -686,6 +690,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)
+        #print_rank_0('HERE4')

        # Autoresume
        if args.adlr_autoresume and \
@@ -694,12 +699,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              lr_scheduler)

        # Evaluation
-        if args.eval_interval and iteration % args.eval_interval == 0 and \
-           args.do_valid:
-            prefix = 'iteration {}'.format(iteration)
-            evaluate_and_print_results(prefix, forward_step_func,
-                                       valid_data_iterator, model,
-                                       iteration, False)
+        #if args.eval_interval and iteration % args.eval_interval == 0 and \
+        #   args.do_valid:
+        #    prefix = 'iteration {}'.format(iteration)
+        #    evaluate_and_print_results(prefix, forward_step_func,
+        #                               valid_data_iterator, model,
+        #                               iteration, False)

        # Checkpointing
        saved_checkpoint = False
@@ -873,8 +878,9 @@ def build_train_valid_test_data_iterators(
        # Build dataloders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
-        valid_dataloader = build_pretraining_data_loader(
-            valid_ds, args.consumed_valid_samples)
+        #valid_dataloader = build_pretraining_data_loader(
+        #    valid_ds, args.consumed_valid_samples)
+        valid_dataloader = None
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
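
As a side note on the second file: with valid_dataloader hard-coded to None, the do_valid flag computed in the "# Flags ..." block presumably ends up False as well, so validation is disabled even where it is still guarded by args.do_valid. A rough self-contained sketch of that assumed downstream effect (stand-in names, not the real training.py):

    # Stand-ins for args.eval_iters and the dataloader built above (assumed names).
    eval_iters = 10
    valid_dataloader = None  # as hard-coded in this commit

    # Assumed shape of the flag logic hinted at by the trailing comment.
    do_valid = valid_dataloader is not None and eval_iters > 0

    # False -> even if the commented-out evaluation block in train() were
    # restored, its args.do_valid guard would skip evaluate_and_print_results.
    print(do_valid)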