Commit 76960d7c authored by Jared Casper

Move rearranging query_key_value and key_value values in old checkpoints to...

Move rearranging query_key_value and key_value values in old checkpoints to when the checkpoint is loaded instead of at runtime.
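The change rests on a simple equivalence: for a linear layer computing y = x @ W.T + b, reordering the rows of W (and the entries of b) reorders the output features exactly as reordering the columns of y would, so the query/key/value fixup can be applied once to the checkpointed weights at load time rather than to the activations on every forward pass. A minimal sketch of that equivalence, with made-up sizes and names (not code from the repository):

import torch

# Toy sizes: h = hidden size, out = 3 * np * hn for a fused QKV projection.
h, out = 8, 12
x = torch.randn(4, h)                        # a batch of activations
W, b = torch.randn(out, h), torch.randn(out)

perm = torch.randperm(out)                   # any reordering of the output features

# Old approach: run the layer, then reorder the activation features.
y_runtime_fix = (x @ W.t() + b)[:, perm]
# New approach: reorder the weight rows (and bias) once, then run the layer.
y_loadtime_fix = x @ W[perm].t() + b[perm]

assert torch.allclose(y_runtime_fix, y_loadtime_fix)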
parent c7444380
+64 −3
@@ -23,9 +23,10 @@ import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from megatron import mpu, get_args, update_num_microbatches
from megatron import get_args
from megatron import print_rank_0
from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches)

_CHECKPOINT_VERSION = None

@@ -163,6 +164,43 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module here, but the values extracted are not
    # specific to self attention, so this should work for cross attention as well.
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) +\
             input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t

def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
@@ -261,6 +299,29 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    # Model.
    model.load_state_dict(state_dict['model'], strict=strict)

    # Fix up query/key/value matrix ordering
    if get_checkpoint_version() < 2.0:
        checkpoint_version = get_checkpoint_version()
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
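As a sanity check on the view/transpose/view sequence in _transpose_first_dim above, here is a standalone sketch with toy sizes (none of these values come from the repository) confirming that a row stored at position (split, head, k) in the version-0 layout [num_splits * np * hn, h] ends up at position (head, split, k) in the converted layout [np * num_splits * hn, h]:

import torch

num_splits, np_, hn, h = 3, 2, 4, 5        # q/k/v, heads per partition, head dim, hidden
t = torch.arange(num_splits * np_ * hn * h, dtype=torch.float)
t = t.view(num_splits * np_ * hn, h)

# Same sequence as the num_splits_first branch above.
fixed = (t.view(num_splits, np_, hn, h)
          .transpose(0, 1).contiguous()
          .view(num_splits * np_ * hn, h))

split, head, k = 2, 1, 3
old_row = t[(split * np_ + head) * hn + k]
new_row = fixed[(head * num_splits + split) * hn + k]
assert torch.equal(old_row, new_row)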
+0 −49
@@ -21,7 +21,6 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
@@ -185,36 +184,6 @@ class ParallelAttention(MegatronModule):
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn] 
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits] 
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]
@@ -227,15 +196,6 @@ class ParallelAttention(MegatronModule):
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
@@ -250,15 +210,6 @@ class ParallelAttention(MegatronModule):
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
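For comparison, a standalone sketch (again with made-up sizes, not code from the repository) of what the removed runtime path did: the same reordering, applied to the last dimension of the activations, [s, b, num_splits * np * hn] --> [s, b, np * num_splits * hn]. The load-time weight fixup added above now makes this per-forward-pass work unnecessary:

import torch

s, b, num_splits, np_, hn = 2, 1, 3, 2, 4
mixed = torch.arange(s * b * num_splits * np_ * hn, dtype=torch.float)
mixed = mixed.view(s, b, num_splits * np_ * hn)

# num_splits_first=True branch of the removed _transpose_last_dim.
shape = mixed.size()
fixed = (mixed.view(*shape[:-1], num_splits, np_, hn)
              .transpose(-2, -3).contiguous()
              .view(*shape))

# Feature (split, head, k) moves from index (split * np_ + head) * hn + k
# to index (head * num_splits + split) * hn + k.
split, head, k = 1, 0, 2
assert torch.equal(mixed[..., (split * np_ + head) * hn + k],
                   fixed[..., (head * num_splits + split) * hn + k])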
+0 −4
@@ -243,10 +243,6 @@ def main():
        print(f'> loading {checkpoint_name} ...')
        load_checkpoint(model_, None, None)
        print(f'> checkpoint version {get_checkpoint_version()}')
        if get_checkpoint_version() < 2.0:
            # Need to deal with the qkv matrix order of old versions
            print("Checkpoints less than version 2.0 are not currently supported.")
            exit()
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.