megatron/mpu/layers.py  +4 −3

@@ -110,11 +110,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
+        self.model_parallel_size = get_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
                 self.num_embeddings, get_model_parallel_rank(),
-                get_model_parallel_world_size())
+                self.model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index

@@ -127,7 +128,7 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.num_embeddings_per_partition, 0, init_method)

     def forward(self, input_):
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
                          (input_ >= self.vocab_end_index)

@@ -142,7 +143,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                              self.norm_type, self.scale_grad_by_freq,
                              self.sparse)
         # Mask the output embedding.
-        if self.num_embeddings_per_partition < self.num_embeddings:
+        if self.model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
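The two `forward` changes replace an indirect test with a direct one: since each rank's partition holds `num_embeddings / world_size` entries, `num_embeddings_per_partition < num_embeddings` holds exactly when the model-parallel world size exceeds one, so `self.model_parallel_size > 1` states the intent plainly and reuses the value now cached in `__init__`. For context, a minimal sketch of the range helper the diff threads that cached value into, assuming the global vocabulary divides evenly across ranks; this is an illustration of the partitioning scheme, not the repository's exact code:

def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    # Each rank owns one contiguous, equally sized slice of the vocabulary.
    per_partition_vocab_size = global_vocab_size // world_size
    index_first = rank * per_partition_vocab_size         # vocab_start_index
    index_last = index_first + per_partition_vocab_size   # vocab_end_index
    return index_first, index_last

With world_size == 1 this returns the whole vocabulary, the mask built in forward would be all-False, and the masking work is pure overhead — which is exactly what the new guard skips.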
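To see why the guarded path is correct when it does run, here is a hedged single-process simulation of the mask-then-reduce scheme. Names, sizes, and the loop over "ranks" are illustrative, not Megatron-LM's API, and the plain summation stands in for reduce_from_model_parallel_region:

# Simulate the sharded embedding lookup on one process.
import torch
import torch.nn.functional as F

global_vocab, dim, world_size = 8, 4, 2
torch.manual_seed(0)
full_weight = torch.randn(global_vocab, dim)   # the unsharded embedding table
input_ = torch.tensor([0, 3, 5, 7])            # token ids spanning both shards

partials = []
for rank in range(world_size):
    start = rank * (global_vocab // world_size)    # vocab_start_index
    end = start + global_vocab // world_size       # vocab_end_index
    shard = full_weight[start:end]                 # this rank's weight partition
    mask = (input_ < start) | (input_ >= end)      # tokens owned by other ranks
    local = input_ - start                         # shift into shard-local ids
    local[mask] = 0                                # any in-range index is safe here
    out = F.embedding(local, shard)
    out[mask, :] = 0.0                             # zero rows this rank doesn't own
    partials.append(out)

# Summing the per-rank partials mimics the all-reduce; the result matches
# an ordinary lookup against the full table.
assert torch.equal(sum(partials), F.embedding(input_, full_weight))

Each simulated rank embeds only the tokens it owns, zeroes the rows for everyone else's tokens, and the element-wise sum reproduces the unsharded lookup — the masking is what keeps the all-reduce from double-counting.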