Commit 35bea728 authored by Boris Fomitchev
Browse files

Code review comments - changing parallel test condition

parent 84a5997a
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.model_parallel_size = get_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
                self.model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

@@ -127,7 +128,7 @@ class VocabParallelEmbedding(torch.nn.Module):
            self.num_embeddings_per_partition, 0, init_method)

    def forward(self, input_):
        if self.num_embeddings_per_partition < self.num_embeddings:
        if self.model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
@@ -142,7 +143,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.num_embeddings_per_partition < self.num_embeddings:
        if self.model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_model_parallel_region(output_parallel)