Commit 5ca20cdd authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'iter_read_update' into 'main'

simplified the iteration read check across ranks

See merge request ADLR/megatron-lm!307
parents a8f4edcb ede0a58f
Loading
Loading
Loading
Loading
+9 −16
Original line number Diff line number Diff line
@@ -124,26 +124,19 @@ def read_metadata(tracker_filename):
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Make sure all the ranks read the same meta data.
    iters_cuda = torch.cuda.LongTensor(
        torch.distributed.get_world_size()).fill_(0)
    iters_cuda[torch.distributed.get_rank()] = iteration
    torch.distributed.all_reduce(iters_cuda)
    # Get the max iteration retrieved across the ranks.
    iters_cuda = torch.cuda.LongTensor([iteration])
    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()

    # We should now have all the same iteration.
    # If not, print a warning and chose the maximum
    # iteration across all ranks.
    max_iter = iters_cuda.max().item()
    min_iter = iters_cuda.min().item()
    if max_iter == min_iter:
        print_rank_0('> meta data was loaded successfully ...')
    else:
        for rank in range(torch.distributed.get_world_size()):
            if iters_cuda[rank] != max_iters:
                print_rank_0('WARNING: on rank {} found iteration {} in the '
    if iteration != max_iter:
        print('WARNING: on rank {} found iteration {} in the '
              'metadata while max iteration across the ranks '
              'is {}, replacing it with max iteration.'.format(
                                 rank, iters_cuda[rank], max_iter))
                  rank, iteration, max_iter), flush=True)
    return max_iter, release