Loading megatron/checkpointing.py +9 −16 Original line number Diff line number Diff line Loading @@ -124,26 +124,19 @@ def read_metadata(tracker_filename): assert iteration > 0 or release, 'error parsing metadata file {}'.format( tracker_filename) # Make sure all the ranks read the same meta data. iters_cuda = torch.cuda.LongTensor( torch.distributed.get_world_size()).fill_(0) iters_cuda[torch.distributed.get_rank()] = iteration torch.distributed.all_reduce(iters_cuda) # Get the max iteration retrieved across the ranks. iters_cuda = torch.cuda.LongTensor([iteration]) torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) max_iter = iters_cuda[0].item() # We should now have all the same iteration. # If not, print a warning and chose the maximum # iteration across all ranks. max_iter = iters_cuda.max().item() min_iter = iters_cuda.min().item() if max_iter == min_iter: print_rank_0('> meta data was loaded successfully ...') else: for rank in range(torch.distributed.get_world_size()): if iters_cuda[rank] != max_iters: print_rank_0('WARNING: on rank {} found iteration {} in the ' if iteration != max_iter: print('WARNING: on rank {} found iteration {} in the ' 'metadata while max iteration across the ranks ' 'is {}, replacing it with max iteration.'.format( rank, iters_cuda[rank], max_iter)) rank, iteration, max_iter), flush=True) return max_iter, release Loading Loading
megatron/checkpointing.py +9 −16 Original line number Diff line number Diff line Loading @@ -124,26 +124,19 @@ def read_metadata(tracker_filename): assert iteration > 0 or release, 'error parsing metadata file {}'.format( tracker_filename) # Make sure all the ranks read the same meta data. iters_cuda = torch.cuda.LongTensor( torch.distributed.get_world_size()).fill_(0) iters_cuda[torch.distributed.get_rank()] = iteration torch.distributed.all_reduce(iters_cuda) # Get the max iteration retrieved across the ranks. iters_cuda = torch.cuda.LongTensor([iteration]) torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) max_iter = iters_cuda[0].item() # We should now have all the same iteration. # If not, print a warning and chose the maximum # iteration across all ranks. max_iter = iters_cuda.max().item() min_iter = iters_cuda.min().item() if max_iter == min_iter: print_rank_0('> meta data was loaded successfully ...') else: for rank in range(torch.distributed.get_world_size()): if iters_cuda[rank] != max_iters: print_rank_0('WARNING: on rank {} found iteration {} in the ' if iteration != max_iter: print('WARNING: on rank {} found iteration {} in the ' 'metadata while max iteration across the ranks ' 'is {}, replacing it with max iteration.'.format( rank, iters_cuda[rank], max_iter)) rank, iteration, max_iter), flush=True) return max_iter, release Loading