Commit 54282071 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'vijay/diff_query_key_lengths' into 'main'

support for different query key sequence lengths

See merge request ADLR/megatron-lm!151
parents 56e16cba d763f54a
Loading
Loading
Loading
Loading
+27 −27
Original line number Diff line number Diff line
@@ -191,28 +191,28 @@ class ParallelSelfAttention(MegatronModule):

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [s, b, h]
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [s, b, hp] --> [s, b, hp * 3]
        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
        mixed_x_layer, _ = self.query_key_value(hidden_states)
 
        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is not None and \
           checkpoint_version == 0:
            # [s, b, 3 * hp] --> [s, b, hp * 3]
            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

        # [s, b, hp * 3] --> [s, b, np, hn, 3]  
        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]  
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head, 3)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [s, b, np, hn, 3] --> 3 [s, b, np, hn]
        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
        query_layer = mixed_x_layer[:,:,:,:,0]
        key_layer = mixed_x_layer[:,:,:,:,1]
        value_layer = mixed_x_layer[:,:,:,:,2]
@@ -235,19 +235,19 @@ class ParallelSelfAttention(MegatronModule):
        # Raw attention scores. [b, np, s, s]
        # ===================================
        
        # [b, np, s, s]
        # [b, np, sq, sk]
        output_size = (query_layer.size(1), 
                       query_layer.size(2), 
                       query_layer.size(0), 
                       key_layer.size(0))
        
        # [s, b, np, hn] -> [s, b * np, hn]
        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, s, s]
        # preallocting result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1], 
            output_size[2], 
@@ -255,18 +255,18 @@ class ParallelSelfAttention(MegatronModule):
            dtype=query_layer.dtype, 
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, s, s]
        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(matmul_result, 
            query_layer.transpose(0, 1),   # [b * np, s, hn]
            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, s]
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, s, s]
        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)


        # ==================================================
        # Update attention mask for inference. [b, np, s, s]
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
@@ -287,7 +287,7 @@ class ParallelSelfAttention(MegatronModule):
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, s, s]
        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

@@ -298,43 +298,43 @@ class ParallelSelfAttention(MegatronModule):


        # =========================
        # Context layer. [s, b, hp]
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [s, b, np, hn] --> [b, np, s, hn]
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, s, hn]
        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), 
                       value_layer.size(2), 
                       value_layer.size(0), 
                       query_layer.size(0), 
                       value_layer.size(3)) 

        # change view [s, b * np, hn] 
        value_layer = value_layer.view(output_size[2],
        # change view [sk, b * np, hn] 
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)
        
        # change view [b * np, s, s]
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)
        
        # matmul: [b * np, s, hn]
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))

        # change view [b, np, s, hn]
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, s, hn] --> [s, b, np, hn]
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [s, b, np, hn] --> [s, b, hp]
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)


        # =================
        # Output. [s, b, h]
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)