megatron/model/transformer.py  (+27 −27)

@@ -191,28 +191,28 @@ class ParallelSelfAttention(MegatronModule):

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
-        # hidden_states: [s, b, h]
+        # hidden_states: [sq, b, h]

         # =====================
         # Query, Key, and Value
         # =====================

-        # Attention heads [s, b, hp] --> [s, b, hp * 3]
+        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
         if checkpoint_version is not None and \
                 checkpoint_version == 0:
-            # [s, b, 3 * hp] --> [s, b, hp * 3]
+            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

-        # [s, b, hp * 3] --> [s, b, np, hn, 3]
+        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
              self.hidden_size_per_attention_head, 3)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
+        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
         query_layer = mixed_x_layer[:,:,:,:,0]
         key_layer = mixed_x_layer[:,:,:,:,1]
         value_layer = mixed_x_layer[:,:,:,:,2]

@@ -235,19 +235,19 @@ class ParallelSelfAttention(MegatronModule):

         # Raw attention scores. [b, np, s, s]
         # ===================================

-        # [b, np, s, s]
+        # [b, np, sq, sk]
         output_size = (query_layer.size(1),
                        query_layer.size(2),
                        query_layer.size(0),
                        key_layer.size(0))

-        # [s, b, np, hn] -> [s, b * np, hn]
+        # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2],
                                        output_size[0] * output_size[1], -1)
         key_layer = key_layer.view(output_size[3],
                                    output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, s, s]
+        # preallocating result tensor: [b * np, sq, sk]
         matmul_result = torch.empty(
             output_size[0]*output_size[1],
             output_size[2],

@@ -255,18 +255,18 @@ class ParallelSelfAttention(MegatronModule):

             dtype=query_layer.dtype,
             device=torch.cuda.current_device())

-        # Raw attention scores. [b * np, s, s]
+        # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, s, hn]
-            key_layer.transpose(0,1).transpose(1, 2),  # [b * np, hn, s]
+            query_layer.transpose(0, 1),   # [b * np, sq, hn]
+            key_layer.transpose(0,1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))

-        # change view to [b, np, s, s]
+        # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # ==================================================
-        # Update attention mask for inference. [b, np, s, s]
+        # Update attention mask for inference. [b, np, sq, sk]
         # ==================================================

         if get_key_value:

@@ -287,7 +287,7 @@ class ParallelSelfAttention(MegatronModule):

         # Attention probs and dropout
         # ===========================

-        # attention scores and attention mask [b, np, s, s]
+        # attention scores and attention mask [b, np, sq, sk]
         attention_probs = self.scale_mask_softmax(attention_scores,
                                                   attention_mask)

@@ -298,43 +298,43 @@ class ParallelSelfAttention(MegatronModule):

         # =========================
-        # Context layer. [s, b, hp]
+        # Context layer. [sq, b, hp]
         # =========================

         # value_layer -> context layer.
-        # [s, b, np, hn] --> [b, np, s, hn]
+        # [sk, b, np, hn] --> [b, np, sq, hn]

-        # context layer shape: [b, np, s, hn]
+        # context layer shape: [b, np, sq, hn]
         output_size = (value_layer.size(1),
                        value_layer.size(2),
-                       value_layer.size(0),
+                       query_layer.size(0),
                        value_layer.size(3))

-        # change view [s, b * np, hn]
-        value_layer = value_layer.view(output_size[2],
+        # change view [sk, b * np, hn]
+        value_layer = value_layer.view(value_layer.size(0),
                                        output_size[0] * output_size[1], -1)

-        # change view [b * np, s, s]
+        # change view [b * np, sq, sk]
         attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                output_size[2], -1)

-        # matmul: [b * np, s, hn]
+        # matmul: [b * np, sq, hn]
         context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))

-        # change view [b, np, s, hn]
+        # change view [b, np, sq, hn]
         context_layer = context_layer.view(*output_size)

-        # [b, np, s, hn] --> [s, b, np, hn]
+        # [b, np, sq, hn] --> [sq, b, np, hn]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

-        # [s, b, np, hn] --> [s, b, hp]
+        # [sq, b, np, hn] --> [sq, b, hp]
         new_context_layer_shape = context_layer.size()[:-2] + \
             (self.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)

         # =================
-        # Output. [s, b, h]
+        # Output. [sq, b, h]
         # =================

         output, bias = self.dense(context_layer)
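The renamed shape annotations matter because once `layer_past` caching is in play, the query sequence length (`sq`) and the key/value sequence length (`sk`) are no longer equal: during incremental decoding the query covers only the newly generated token, while the cached keys and values cover the whole prefix. Below is a minimal, self-contained sketch of the same shape bookkeeping with toy dimensions; it is not Megatron's module, and the names `b`, `np_`, `hn`, and `norm_factor` are illustrative stand-ins for the batch size, per-partition head count, head dimension, and Megatron's softmax scaling.

    # Standalone sketch (CPU, toy sizes) of the sq != sk case the diff handles.
    import torch

    b, np_, hn = 2, 4, 8          # batch, heads per partition, head dim
    sq, sk = 1, 17                # one new query token vs. 17 cached positions
    norm_factor = hn ** 0.5       # assumed scaling, analogous to self.norm_factor

    query_layer = torch.randn(sq, b, np_, hn)   # [sq, b, np, hn]
    key_layer   = torch.randn(sk, b, np_, hn)   # [sk, b, np, hn]
    value_layer = torch.randn(sk, b, np_, hn)   # [sk, b, np, hn]

    # [b, np, sq, sk]
    output_size = (query_layer.size(1), query_layer.size(2),
                   query_layer.size(0), key_layer.size(0))

    # Fold batch and heads together for batched matmul.
    q = query_layer.view(sq, b * np_, hn)        # [sq, b * np, hn]
    k = key_layer.view(sk, b * np_, hn)          # [sk, b * np, hn]

    # Raw scores: [b * np, sq, hn] x [b * np, hn, sk] -> [b * np, sq, sk]
    scores = torch.baddbmm(
        torch.empty(b * np_, sq, sk),            # ignored since beta=0
        q.transpose(0, 1),
        k.transpose(0, 1).transpose(1, 2),
        beta=0.0, alpha=1.0 / norm_factor)
    attention_scores = scores.view(*output_size)          # [b, np, sq, sk]
    attention_probs = torch.softmax(attention_scores, dim=-1)

    # Context: the view must use value_layer.size(0) (= sk), not sq --
    # exactly the fix in the diff above.
    v = value_layer.view(value_layer.size(0), b * np_, hn)    # [sk, b * np, hn]
    probs = attention_probs.view(b * np_, sq, sk)             # [b * np, sq, sk]
    context = torch.bmm(probs, v.transpose(0, 1))             # [b * np, sq, hn]

    context = context.view(b, np_, sq, hn).permute(2, 0, 1, 3).contiguous()
    print(context.shape)   # torch.Size([1, 2, 4, 8]) == [sq, b, np, hn]

Running this with sq == sk also works, which is why the old `value_layer.view(output_size[2], ...)` only broke once caching made the two lengths diverge: `output_size[2]` now holds `query_layer.size(0)` (sq), while the value tensor's leading dimension is sk.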