Commit 872e38ea authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'merge_bugfix' into 'main'

Fix bug in merge_mp_partitions for handling recent checkpoints.

See merge request ADLR/megatron-lm!226
parents c601d751 72105ef0
Loading
Loading
Loading
Loading
+5 −0
Original line number · Diff line number · Diff line
@@ -240,6 +240,11 @@ def main():
    # Rebuild the tokenizer from the current args before loading partitions.
    tokenizer = rebuild_tokenizer(args)
    # Pretend to be a tensor-model-parallel job of the requested size so that
    # per-rank checkpoint loading below sees the right world size.
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    # Load each tensor-model-parallel rank's checkpoint in turn, all within
    # this single process.
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        # Switch the fake parallel rank, then resolve this rank's checkpoint
        # path and build a fresh model instance to load it into.
        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)