Commit ade99d61 authored by Vijay Korthikanti

avoid allocation of word embedding for head in T5 pp=2 case

parent 26ea8314
@@ -85,7 +85,9 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         #    the two word_embeddings layers to ensure that every applied weight
         #    update is the same on both stages.
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage() and \
+                (not hasattr(self.language_model, 'embedding') or
+                 self.language_model.embedding is None):
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
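For context, here is a minimal standalone sketch (not Megatron-LM code) of the allocation decision after this change. The StageModule class, its constructor arguments, and the has_own_embedding flag are hypothetical stand-ins that mirror the roles of mpu.is_pipeline_last_stage() and self.language_model.embedding; the sketch only illustrates why the extra check skips the duplicate allocation in the T5 pipeline-parallel-size-2 case, where the last stage already owns an embedding.

# Hypothetical sketch: last-stage word_embeddings_for_head allocation.
import torch


class StageModule(torch.nn.Module):
    """Stand-in for one pipeline stage of a Megatron-style model."""

    def __init__(self, vocab_size, hidden_size,
                 is_first_stage, is_last_stage, has_own_embedding):
        super().__init__()
        self.is_first_stage = is_first_stage
        self.is_last_stage = is_last_stage
        # T5 with pipeline parallel size 2: the last (decoder) stage
        # already holds the shared embedding, so has_own_embedding is True.
        self.embedding = (torch.nn.Embedding(vocab_size, hidden_size)
                          if has_own_embedding else None)
        self.word_embeddings_for_head = None

        # Mirrors the updated condition in this commit: allocate the extra
        # embedding for the output head only on the last stage, and only
        # when this stage does not already hold an embedding.
        if self.is_last_stage and self.embedding is None:
            assert not self.is_first_stage
            self.word_embeddings_for_head = torch.nn.Embedding(vocab_size,
                                                               hidden_size)
            # Zero-init; in Megatron-LM the real weights are then synced
            # from the first stage before training starts.
            with torch.no_grad():
                self.word_embeddings_for_head.weight.zero_()


# GPT-style pp=2: only the first stage owns the embedding, so the last
# stage allocates the separate copy for the output head.
gpt_last = StageModule(32, 8, is_first_stage=False, is_last_stage=True,
                       has_own_embedding=False)
assert gpt_last.word_embeddings_for_head is not None

# T5-style pp=2: the last stage already owns an embedding, so no duplicate
# allocation happens -- the point of this commit.
t5_last = StageModule(32, 8, is_first_stage=False, is_last_stage=True,
                      has_own_embedding=True)
assert t5_last.word_embeddings_for_head is None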