Commit f2bf5a56 authored by Vijay Korthikanti

minor fixes

parent 17843605
+2 −1
@@ -597,7 +597,8 @@ class ParallelTransformer(MegatronModule):
                  (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
-            if args.model_type == ModelType.encoder_and_decoder:
+            if args.model_type == ModelType.encoder_and_decoder and \
+                    mpu.get_pipeline_model_parallel_world_size() > 1:
                 pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
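A minimal sketch of the layer-offset logic this hunk changes; the function and argument names below are illustrative, not the Megatron-LM API, and the decoder-offset arithmetic is an assumed reconstruction. The added pipeline-world-size guard keeps single-stage encoder-decoder runs on the plain contiguous-offset path, where no encoder/decoder stage split exists to subtract out.

def layer_offset(model_type, layer_type, pipeline_rank,
                 pipeline_world_size, num_encoder_stages, num_layers):
    """Index of the first transformer layer built on this pipeline stage."""
    if model_type == "encoder_and_decoder" and pipeline_world_size > 1:
        # Encoder and decoder occupy disjoint pipeline stages, so each
        # half numbers its layers from its own first stage.
        if layer_type == "encoder":
            return pipeline_rank * num_layers
        return (pipeline_rank - num_encoder_stages) * num_layers
    # Single stage, or an encoder-only/decoder-only model: each stage
    # simply gets a contiguous block of layers.
    return pipeline_rank * num_layers

# With one pipeline stage there is no encoder/decoder split to account for:
assert layer_offset("encoder_and_decoder", "decoder", 0, 1, 0, 12) == 0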
+2 −2
@@ -30,7 +30,7 @@ _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
 # Position embedding group.
-_POSITION EMBEDDING_GROUP = None
+_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

@@ -208,7 +208,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         if rank in position_embedding_ranks:
             _POSITION_EMBEDDING_GROUP = group
         if rank in ranks:
-            _POSITION_EMBEDDING_GLOBAL_RANKS = embedding_ranks
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks


 def model_parallel_is_initialized():
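A minimal, runnable sketch of the bookkeeping this hunk fixes; the helper below and its hard-coded rank layout are assumptions for illustration, not the Megatron-LM initializer. The point of the fix: the cached global-rank list for the position-embedding group must come from position_embedding_ranks, not from the word-embedding group's rank list.

_EMBEDDING_GLOBAL_RANKS = None
_POSITION_EMBEDDING_GLOBAL_RANKS = None

def cache_embedding_ranks(pipeline_ranks):
    global _EMBEDDING_GLOBAL_RANKS, _POSITION_EMBEDDING_GLOBAL_RANKS
    # Word embeddings are tied between the first and last pipeline
    # stages, so both of those ranks belong to the embedding group.
    embedding_ranks = [pipeline_ranks[0], pipeline_ranks[-1]]
    # Position embeddings live only on the first stage.
    position_embedding_ranks = [pipeline_ranks[0]]
    _EMBEDDING_GLOBAL_RANKS = embedding_ranks
    # Pre-fix, the next line read `= embedding_ranks`, silently recording
    # the last stage as a position-embedding holder as well.
    _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

cache_embedding_ranks([0, 1, 2, 3])
assert _EMBEDDING_GLOBAL_RANKS == [0, 3]
assert _POSITION_EMBEDDING_GLOBAL_RANKS == [0]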