Commit 38a774e9 authored by Jared Casper
Browse files

Merge branch 't5_pipeline_parallelism_grad_norm_fix' into 't5_pipeline_parallelism'

Fix grad norm computation

See merge request ADLR/megatron-lm!296
parents 1c3a1e6a 4f4715db
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -164,9 +164,12 @@ class Embedding(MegatronModule):
    def zero_parameters(self):
        """Reset every embedding table to zero and tag each weight as shared.

        Zeroes the word and position embedding weights in place, and the
        token-type embedding weights as well when token types are configured
        (``num_tokentypes > 0``). Each affected weight also gets a
        ``shared = True`` attribute set on it.
        """
        tables = [self.word_embeddings, self.position_embeddings]
        if self.num_tokentypes > 0:
            tables.append(self.tokentype_embeddings)
        for table in tables:
            table.weight.data.fill_(0)
            table.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add