Commit 6ee3d210 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'leftover_fixes' into 'main'

Added an across-ranks sync for checkpoint iteration load and a couple of other fixes

See merge request ADLR/megatron-lm!304
parents 91fcbd68 0722f878
Loading
Loading
Loading
Loading
+42 −15
Original line number Diff line number Diff line
@@ -106,6 +106,47 @@ def get_checkpoint_tracker_filename(checkpoints_path):
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def read_metadata(tracker_filename):
    """Read the checkpoint tracker file and return (iteration, release).

    The tracker file contains either an integer iteration number or the
    literal string 'release'. After parsing locally, the iteration is
    all-reduced across all data-parallel ranks; if the ranks disagree
    (e.g. a stale filesystem view), the maximum iteration is chosen and
    a warning is printed so that every rank loads the same checkpoint.

    Arguments:
        tracker_filename: path to the latest-checkpointed-iteration file.

    Returns:
        (max_iter, release): the iteration agreed upon across ranks, and
        whether this is a 'release' checkpoint.

    Exits the process (sys.exit) if the file contents are neither an
    integer nor 'release'.
    """
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Make sure all the ranks read the same meta data.
    # One slot per rank; each rank writes its own iteration, then the
    # sum-all-reduce gathers every rank's value (all other slots are 0).
    iters_cuda = torch.cuda.LongTensor(
        torch.distributed.get_world_size()).fill_(0)
    iters_cuda[torch.distributed.get_rank()] = iteration
    torch.distributed.all_reduce(iters_cuda)

    # We should now have all the same iteration.
    # If not, print a warning and choose the maximum
    # iteration across all ranks.
    max_iter = iters_cuda.max().item()
    min_iter = iters_cuda.min().item()
    if max_iter == min_iter:
        print_rank_0('> meta data was loaded successfully ...')
    else:
        for rank in range(torch.distributed.get_world_size()):
            # Bug fix: original referenced undefined name `max_iters`,
            # raising NameError whenever ranks disagreed.
            if iters_cuda[rank] != max_iter:
                print_rank_0('WARNING: on rank {} found iteration {} in the '
                             'meta data while max iteration across the ranks '
                             'is {}, replacing it with max iteration.'.format(
                                 rank, iters_cuda[rank], max_iter))
    return max_iter, release


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()
@@ -260,21 +301,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+4 −5
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ def pretrain(train_valid_test_dataset_provider,
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
@@ -787,8 +787,7 @@ def build_train_valid_test_data_iterators(
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size