Commit c1faa9fe authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'ckpt_rng' into 'main'

Improve handling of rng states in checkpoints.

See merge request ADLR/megatron-lm!231
parents 8863af8c 08a848c7
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -494,9 +494,9 @@ def _add_checkpointing_args(parser):
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true',
    group.add_argument('--no-save-optim', action='store_true', default=None,
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true',
    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
+5 −2
Original line number Diff line number Diff line
@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            # Check for empty states array
            if not state_dict['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

+3 −1
Original line number Diff line number Diff line
@@ -92,7 +92,9 @@ def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
+2 −0
Original line number Diff line number Diff line
@@ -200,6 +200,8 @@ def main():
                                          'micro_batch_size': 1,
                                          'no_load_optim': True,
                                          'no_load_rng': True,
                                          'no_save_optim': True,
                                          'no_save_rng': True,
                                          'save_interval': 1})
    args = get_args()