Commit ea128da5 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

t5_pipeline_fix

parent 3ae12a47
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -177,7 +177,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            if pipeline_model_parallel_split_rank_ is not None and \
                    pipeline_model_parallel_split_rank_ not in embedding_ranks:
                    ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                embedding_ranks = [ranks[0],
                                   ranks[pipeline_model_parallel_split_rank_],
                                   ranks[-1]]