Loading tools/merge_mp_partitions.py +5 −0 Original line number Diff line number Diff line Loading @@ -240,6 +240,11 @@ def main(): tokenizer = rebuild_tokenizer(args) mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) for rank in range(args.tensor_model_parallel_size): # Reset these since load_checkpoint asserts they are 0, but we are loading # multiple checkpoints in the same process and they get set each time args.consumed_train_samples = 0 args.consumed_valid_samples = 0 mpu.initialize.set_tensor_model_parallel_rank(rank) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) model_ = get_model(model_type) Loading Loading
tools/merge_mp_partitions.py +5 −0 Original line number Diff line number Diff line Loading @@ -240,6 +240,11 @@ def main(): tokenizer = rebuild_tokenizer(args) mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) for rank in range(args.tensor_model_parallel_size): # Reset these since load_checkpoint asserts they are 0, but we are loading # multiple checkpoints in the same process and they get set each time args.consumed_train_samples = 0 args.consumed_valid_samples = 0 mpu.initialize.set_tensor_model_parallel_rank(rank) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) model_ = get_model(model_type) Loading