Commit d439b417 authored by Vijay Korthikanti's avatar Vijay Korthikanti

update comment

parent 9da6e975
+2 −3
@@ -109,9 +109,8 @@ class MegatronModule(torch.nn.Module):
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())

-           # All-reduce other embeddings as well as necessary. The last stage
-           # does not have these other embeddings, so just create placeholder
-           # tensors of the right shape with all zeros.
+           # Ensure that encoder(first stage) and decoder(split stage) position
+           # embeddings have the same initial parameter values.
            # NOTE: We don't currently support T5 with the interleaved schedule.
            if mpu.is_rank_in_position_embedding_group() and \
                    args.pipeline_model_parallel_split_rank is not None:
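
For context, the all-reduce above is the trick Megatron uses to keep duplicated embedding weights identical across pipeline stages. Below is a minimal sketch of the general pattern, assuming an already-initialized torch.distributed process group; sync_shared_embedding and is_source_stage are illustrative names, not Megatron-LM APIs:

    import torch
    import torch.distributed as dist

    def sync_shared_embedding(weight: torch.Tensor, group, is_source_stage: bool):
        # Zero the placeholder copies on non-source stages, then all-reduce.
        # The sum across the group then equals the source stage's values, so
        # every rank in `group` ends up with identical initial parameters.
        if not is_source_stage:
            weight.data.fill_(0)
        dist.all_reduce(weight.data, group=group)

The same idea applies to the position embeddings touched by this commit: every rank in the position embedding group participates in one all-reduce so the encoder's and decoder's copies start out equal.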