Commit 538f0e05 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'vijay/ordering_perf_fix' into 'main'

reordering perf fix

See merge request ADLR/megatron-lm!154
parents 42d21122 01dffdb5
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 1.0
        state_dict['checkpoint_version'] = 2.0
        state_dict['iteration'] = iteration
        state_dict['model'] = model.state_dict_for_save_checkpoint()

+41 −25
Original line number Diff line number Diff line
@@ -172,17 +172,30 @@ class ParallelSelfAttention(MegatronModule):
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer):
        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(tranpose)
        [s, b, hp, 3] -->(view) [s, b, 3 * hp] """

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size();
        last_dim = input_shape[-1]
        assert last_dim % 3 == 0, "expected QKV dimension"
        last_dim_split = last_dim // 3
        if num_splits_first:
            """[s, b, num_splits * np * hn] 
            -->(view) [s, b, num_splits, np, hn] 
            -->(tranpose) [s, b, np, num_splits, hn] 
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
            (3, last_dim_split)
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits] 
            -->(view) [s, b, np, hn, num_splits] 
            -->(tranpose) [s, b, np, num_splits, hn] 
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)
@@ -197,25 +210,28 @@ class ParallelSelfAttention(MegatronModule):
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        checkpoint_version = get_checkpoint_version()
        if checkpoint_version is not None and \
           checkpoint_version == 0:
            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]  
        if checkpoint_version is not None:
           if checkpoint_version == 0:
               # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
               mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
           elif checkpoint_version == 1.0:
               # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
               mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head, 3)
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
        query_layer = mixed_x_layer[:,:,:,:,0]
        key_layer = mixed_x_layer[:,:,:,:,1]
        value_layer = mixed_x_layer[:,:,:,:,2]
        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        # ==================================
        # Adjust key and value for inference