Commit 8a6e56b8 authored by mohammad, committed by Deepak Narayanan

Cleaned up load/save checkpoint printing

parent b81cad66
+14 −8
@@ -110,6 +110,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, torchDDP):
         model = model.module
+
+    if torch.distributed.get_rank() == 0:
+        print('saving checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
     if mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -137,14 +142,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
-        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-              format(torch.distributed.get_rank(), iteration, checkpoint_name))
         ensure_directory_exists(checkpoint_name)
         torch.save(state_dict, checkpoint_name)
-        print('  successfully saved {}'.format(checkpoint_name))
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
@@ -192,9 +197,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):

     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' loading checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     # Load the checkpoint.
     try:
@@ -276,8 +281,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             sys.exit()
 
     torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print('  successfully loaded {}'.format(checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print('  successfully loaded checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     return iteration
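
For context, the pattern this commit settles on can be sketched in isolation: informational messages are printed only by global rank 0, the actual write to disk still happens on the rank-0 process of each data-parallel group, and a barrier makes every rank wait before success is announced. The sketch below is illustrative only, not the repository's code: save_with_clean_printing and get_data_parallel_rank are hypothetical stand-ins for save_checkpoint and mpu.get_data_parallel_rank, and it assumes torch.distributed has already been initialized.

# Minimal sketch of the rank-0-only printing pattern, under the assumptions above.
import torch
import torch.distributed as dist


def save_with_clean_printing(state_dict, checkpoint_name, iteration, save_dir,
                             get_data_parallel_rank):
    # Announce the save once, from global rank 0 only.
    if dist.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, save_dir), flush=True)

    # Only the rank-0 process of each data-parallel group writes to disk.
    if get_data_parallel_rank() == 0:
        torch.save(state_dict, checkpoint_name)

    # Wait until every rank is done before declaring success, again on rank 0 only.
    dist.barrier()
    if dist.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, save_dir), flush=True)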