Commit e2a4d426 authored by mohammad
Browse files

Fixed a bug in consumed-tokens initialization: read consumed_train_samples / consumed_valid_samples from the checkpoint's saved args rather than the current args.

parent 75bd9b54
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0)
        args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')