Commit 9da6e975 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

simplify code

parent ade99d61
Loading
Loading
Loading
Loading
+3 −7
Original line number Diff line number Diff line
@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module):


    def word_embeddings_weight(self):
        if hasattr(self.language_model, 'embedding') and \
                self.language_model.embedding is not None:
        if self.pre_process:
            return self.language_model.embedding.word_embeddings.weight
        else:
            if not self.share_word_embeddings:
@@ -86,8 +85,7 @@ class MegatronModule(torch.nn.Module):
        #    the two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage() and \
                (not hasattr(self.language_model, 'embedding') or 
                 self.language_model.embedding is None):
                not self.pre_process:
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
@@ -101,9 +99,7 @@ class MegatronModule(torch.nn.Module):
        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
                mpu.is_rank_in_embedding_group() and \
                hasattr(self.language_model, 'embedding') and \
                self.language_model.embedding is not None:
                self.pre_process:
            self.language_model.embedding.zero_parameters()

        # Ensure that first and last stages have the same initial parameter