Commit 5c8238c3 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Fix grad norm computation

parent da1c96e9
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -164,9 +164,12 @@ class Embedding(MegatronModule):
    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
@@ -331,10 +334,6 @@ class TransformerLanguageModel(MegatronModule):
        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            # Temporary assertion until we verify correctness of pipeline parallelism
            # implementation of T5.
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,