Commit 98a5b9a0 authored by Jared Casper

Improve comments around layer regex replacement.

parent 7cabbe67
+6 −3
@@ -304,6 +304,7 @@ def main():
     mpu.initialize.set_tensor_model_parallel_rank(0)
     mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
 
+    # regex to parse out layer number from param name
     layer_re = re.compile('layers\.([0-9]+)')
 
     if args.pipeline_model_parallel_size > 1:
@@ -314,7 +315,7 @@ def main():
         for rank in range(args.pipeline_model_parallel_size):
             mpu.initialize.set_pipeline_model_parallel_rank(rank)
             model = get_model(model_type)
-            def repl_layer(m):
+            def update_layer_num(m):
                 # TODO! This assumes no interleaved pipeline execution
                 layer = int(m.group(1))
                 layer += rank * layers_per_part
@@ -325,8 +326,10 @@ def main():
                     # See comment in MegatronModule.initialize_word_embeddings()
                     src_name = "language_model.embedding.word_embeddings.weight"
                 else:
-                    src_name = re.sub(layer_re, repl_layer, dst_name)
-                print(f" > copying {src_name} to {dst_name} rank {rank}'s model")
+                    # Translate destination layer number (0-N for each partition)
+                    # to source layer number (single-model layer number)
+                    src_name = re.sub(layer_re, update_layer_num, dst_name)
+                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                 partition_param.data.copy_(merged_params[src_name].data)
 
             partitions.append(model)
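
For context, a minimal standalone sketch of the translation the new comments describe: the replacement callback shifts each partition-local layer index (0-N within a pipeline stage) by the stage's offset so the parameter name resolves against the merged, single-model checkpoint. The return statement, the example parameter name, and the concrete values of rank and layers_per_part are illustrative assumptions; only the regex, the callback name, and the offset arithmetic come from the diff.

import re

# Regex from the diff: captures the layer number in a param name.
layer_re = re.compile(r'layers\.([0-9]+)')

# Illustrative values: e.g. a 24-layer model split over 2 pipeline stages.
layers_per_part = 12
rank = 1  # pipeline stage whose partition is being populated

def update_layer_num(m):
    # Shift the partition-local layer number by this stage's offset to
    # get the layer number used in the merged model's parameter names.
    layer = int(m.group(1)) + rank * layers_per_part
    return f'layers.{layer}'  # assumed return format, not shown in the hunk

dst_name = 'language_model.encoder.layers.2.mlp.dense_h_to_4h.weight'
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(src_name)
# language_model.encoder.layers.14.mlp.dense_h_to_4h.weight

Layer 2 of stage 1's partition thus maps to layer 14 of the merged model, which is the name the merged_params lookup needs.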