Commit 2de7ae27 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 't5_pipeline_fixes' into 'main'

Generic fix to T5 pipeline parallelism bug.

See merge request ADLR/megatron-lm!376
parents 9d86ca67 d439b417
Loading
Loading
Loading
Loading
+15 −20
Original line number Diff line number Diff line
@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module):


    def word_embeddings_weight(self):
        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
                mpu.get_pipeline_model_parallel_world_size() == 1:
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_word_embeddings:
@@ -85,7 +84,8 @@ class MegatronModule(torch.nn.Module):
        # 3. In the training loop, before an all-reduce between the grads of
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage():
        if mpu.is_pipeline_last_stage() and \
                not self.pre_process:
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
@@ -99,8 +99,7 @@ class MegatronModule(torch.nn.Module):
        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
                mpu.is_rank_in_embedding_group():
                self.pre_process:
            self.language_model.embedding.zero_parameters()

        # Ensure that first and last stages have the same initial parameter
@@ -109,21 +108,17 @@ class MegatronModule(torch.nn.Module):
            if mpu.is_rank_in_embedding_group():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
                # All-reduce other embeddings as well as necessary. The last stage
                # does not have these other embeddings, so just create placeholder
                # tensors of the right shape with all zeros.

            # Ensure that encoder(first stage) and decoder(split stage) position 
            # embeddings have the same initial parameter values
            # NOTE: We don't currently support T5 with the interleaved schedule.
                if args.pipeline_model_parallel_split_rank is not None:
            if mpu.is_rank_in_position_embedding_group() and \
                    args.pipeline_model_parallel_split_rank is not None:
                # TODO: Support tokentype embedding.
                    dimensions = (args.max_position_embeddings, args.hidden_size)
                    if mpu.is_pipeline_last_stage(ignore_virtual=True):
                        position_embeddings = torch.nn.Embedding(*dimensions).cuda()
                        position_embeddings.weight.data.fill_(0)
                    else:
                self.language_model.embedding.cuda()
                position_embeddings = self.language_model.embedding.position_embeddings
                torch.distributed.all_reduce(position_embeddings.weight.data,
                                                 group=mpu.get_embedding_group())
                                             group=mpu.get_position_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "