Commit 72105ef0 authored by Jared Casper
Browse files

Fix bug in merge_mp_partitions for handling recent checkpoints.

parent c601d751
Loading
Loading
Loading
Loading
+5 −0
Original line number | Diff line number | Diff line
@@ -240,6 +240,11 @@ def main():
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)