Commit 3aacd955 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'consumed_tokens_restart_fix' into 'main'

Found a bug in consumed-tokens initialization: the values were read from the live `args` instead of the checkpoint's saved `checkpoint_args`, so restarts always reset them to 0.

See merge request ADLR/megatron-lm!182
parents 75bd9b54 e2a4d426
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0)
        args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')