Commit 390ddef8 authored by mshoeybi

added multi-batch inference

parent b6555b71
+16 −7
@@ -22,9 +22,20 @@ from megatron.p2p_communication import recv_forward, send_forward
 from megatron import get_args
 
 
-def forward_step(model, tokens, position_ids, attention_mask,
-                 set_inference_key_value_memory=False,
-                 inference_max_sequence_len=None):
+class InferenceParams:
+
+    def __init__(self, micro_batch_size_list, max_sequence_len):
+
+        assert isinstance(micro_batch_size_list, list)
+        assert max_sequence_len > 0
+
+        self.micro_batch_size_list = micro_batch_size_list
+        self.max_sequence_len = max_sequence_len
+        self.allocate_key_value_memory = False
+        self.micro_batch_size_index = 0
+
+
+def forward_step(model, tokens, position_ids, attention_mask, inference_params):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
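
Taken together, this file's hunks replace the pair of loose keyword arguments
(set_inference_key_value_memory, inference_max_sequence_len) with a single
InferenceParams container, and generalize the batch size to a list so that
several micro-batches can each own a separate key-value cache. A minimal
sketch of constructing one (the import path is assumed from this file's
location; the sizes are illustrative):

    # Assumed path; InferenceParams is defined in the hunk above.
    from megatron.text_generation.forward_step import InferenceParams

    # Two micro-batches, of 4 and 2 sequences, decoding up to 512 tokens each.
    inference_params = InferenceParams(micro_batch_size_list=[4, 2],
                                       max_sequence_len=512)
    inference_params.allocate_key_value_memory = True  # (re)allocate on next pass
    inference_params.micro_batch_size_index = 0        # which cache the pass uses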
@@ -37,10 +48,8 @@ def forward_step(model, tokens, position_ids, attention_mask,

     # Forward pass through the model.
     model.set_input_tensor(input_tensor)
-    output_tensor = model(
-        tokens, position_ids, attention_mask,
-        set_inference_key_value_memory=set_inference_key_value_memory,
-        inference_max_sequence_len=inference_max_sequence_len)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
 
     send_forward(output_tensor)

+8 −6
@@ -25,7 +25,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step
+from .forward_step import forward_step, InferenceParams
 from .sampling import sample
 
 
@@ -109,6 +109,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)
 
+    # Set inference params
+    inference_params = InferenceParams([batch_size], max_sequence_length)
+
     model.eval()
     with torch.no_grad():
         prev_context_length = 0
@@ -117,7 +120,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             # If we are starting from scratch, allocate memory for the entire
             # context, otherwise set this to false so the memory is not
             # reallocated.
-            set_inference_key_value_memory = (prev_context_length == 0)
+            inference_params.allocate_key_value_memory = \
+                (prev_context_length == 0)
 
             # Pick the slice that we need to pass through the network.
             tokens2use = tokens[:, prev_context_length:context_length]
@@ -126,10 +130,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                 ..., prev_context_length:context_length, :context_length]
 
             # logits will be meaningful only in the last pipeline stage.
-            logits = forward_step(
-                model, tokens2use, positions2use, attention_mask2use,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=max_sequence_length)
+            logits = forward_step(model, tokens2use, positions2use,
+                                  attention_mask2use, inference_params)
 
             if mpu.is_pipeline_last_stage():
                 # The last stage should always have an output.
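
The loop body above now reads as: allocate the key-value cache once, on the
first iteration, then push only the not-yet-cached token slice through the
network on every later step. Condensed, the pattern is (sampling and
early-exit logic omitted; prompt_length is an illustrative name):

    inference_params = InferenceParams([batch_size], max_sequence_length)

    prev_context_length = 0
    for context_length in range(prompt_length, max_sequence_length):
        # Allocate cache memory only when starting from scratch.
        inference_params.allocate_key_value_memory = (prev_context_length == 0)

        # Only the new slice goes through the model; the rest is cached.
        tokens2use = tokens[:, prev_context_length:context_length]
        positions2use = position_ids[:, prev_context_length:context_length]
        attention_mask2use = attention_mask[
            ..., prev_context_length:context_length, :context_length]

        logits = forward_step(model, tokens2use, positions2use,
                              attention_mask2use, inference_params)
        # ... sample the next token from logits, append it to tokens ...
        prev_context_length = context_length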
+2 −5
@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                tokentype_ids=None, inference_params=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)
 
         if self.post_process:
             return post_language_model_processing(
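
Because inference_params defaults to None here and in every layer below,
training and plain evaluation are unaffected: a call that omits it behaves
exactly as before, while inference threads the container through. A hedged
usage sketch:

    # Training / scoring: no cache, identical to the old code path.
    lm_output = model(tokens, position_ids, attention_mask, labels=labels)

    # Incremental inference: the same call with the cache container attached.
    lm_output = model(tokens, position_ids, attention_mask,
                      inference_params=inference_params)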
+3 −6
@@ -335,8 +335,7 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None,
+                inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
@@ -353,8 +352,7 @@ class TransformerLanguageModel(MegatronModule):
             encoder_output = self.encoder(
                 encoder_input,
                 enc_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
 
@@ -381,8 +379,7 @@ class TransformerLanguageModel(MegatronModule):
             dec_attn_mask,
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
+47 −48
@@ -180,9 +180,9 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)
 
         # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-        self.inference_current_sequence_len = 0
+        self.inference_key_memory_list = None
+        self.inference_value_memory_list = None
+        self.inference_current_sequence_len_list = None
 
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -196,35 +196,32 @@ class ParallelAttention(MegatronModule):
        

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]
 
 
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        if set_inference_key_value_memory:
-            assert inference_max_sequence_len and inference_max_sequence_len > 0
-            self.inference_key_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_value_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_current_sequence_len = 0
-        # Some consistency check.
-        if inference_max_sequence_len:
-            assert self.inference_current_sequence_len < \
-                self.inference_key_memory.size(0)
-            assert inference_max_sequence_len == \
-                self.inference_key_memory.size(0)
-        # This is added for safety. In case inference_max_sequence_len
+        if inference_params:
+            if inference_params.allocate_key_value_memory:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_batch_sizes = inference_params.micro_batch_size_list
+                self.inference_key_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_value_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_current_sequence_len_list = [
+                    0 for _ in inf_batch_sizes]
+        # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
-        if not inference_max_sequence_len:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
-
+        else:
+            self.inference_key_memory_list = None
+            self.inference_value_memory_list = None
+            self.inference_current_sequence_len_list = None
 
         # =====================
         # Query, Key, and Value
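
The single key/value buffers become per-micro-batch lists: when
allocate_key_value_memory is set, the layer builds one pre-sized buffer pair
per entry of micro_batch_size_list, and the else branch drops any stale
buffers when no inference_params is passed. The body of _allocate_memory is
elided in this diff; a plausible sketch, assuming the usual
[sequence, batch, heads, head_dim] cache layout of this codebase:

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        # Sketch only -- the real body is not shown in this diff.
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())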
@@ -267,20 +264,27 @@ class ParallelAttention(MegatronModule):
             query_layer = query_layer.view(*new_tensor_shape)
 
 
-        # ===================================================
-        # Adjust key, value, and attention mask for inference
-        # ===================================================
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================
 
-        if inference_max_sequence_len:
+        if inference_params:
+            inf_batch_index = inference_params.micro_batch_size_index
+            assert key_layer.size(1) == \
+                inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
-            start = self.inference_current_sequence_len
-            self.inference_current_sequence_len += key_layer.size(0)
-            end = self.inference_current_sequence_len
+            start = self.inference_current_sequence_len_list[inf_batch_index]
+            end = start + key_layer.size(0)
+            self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy key and values.
-            self.inference_key_memory[start:end, ...] = key_layer
-            self.inference_value_memory[start:end, ...] = value_layer
-            key_layer = self.inference_key_memory[:end, ...]
-            value_layer = self.inference_value_memory[:end, ...]
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
+                key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
+                value_layer
+            key_layer = \
+                self.inference_key_memory_list[inf_batch_index][:end, ...]
+            value_layer = \
+                self.inference_value_memory_list[inf_batch_index][:end, ...]
 
 
         # ===================================
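
Each pass appends its fresh keys and values into the cache selected by
micro_batch_size_index and then attends over everything cached so far. For a
10-token prompt followed by single-token decode steps, the index arithmetic
runs like this (illustrative trace):

    idx = inference_params.micro_batch_size_index
    # prompt pass:  key_layer.size(0) == 10 -> start=0,  end=10
    # decode step:  key_layer.size(0) == 1  -> start=10, end=11
    start = self.inference_current_sequence_len_list[idx]
    end = start + key_layer.size(0)
    self.inference_current_sequence_len_list[idx] = end
    self.inference_key_memory_list[idx][start:end, ...] = key_layer
    # Attention then sees the full cached prefix, not just the new slice:
    key_layer = self.inference_key_memory_list[idx][:end, ...]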
@@ -459,10 +463,8 @@ class ParallelTransformerLayer(MegatronModule):
                                output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
@@ -472,8 +474,7 @@ class ParallelTransformerLayer(MegatronModule):
             self.self_attention(
                 layernorm_output,
                 attention_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -686,13 +687,11 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = input_tensor
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
 
         # Checks.
-        if inference_max_sequence_len:
+        if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'
 
@@ -724,8 +723,8 @@ class ParallelTransformer(MegatronModule):
                     attention_mask,
                     encoder_output=encoder_output,
                     enc_dec_attn_mask=enc_dec_attn_mask,
-                    set_inference_key_value_memory=set_inference_key_value_memory,
-                    inference_max_sequence_len=inference_max_sequence_len)
+                    inference_params=inference_params)
 
 
         # Final layer norm.
         if self.post_process:
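
Nothing in this commit moves micro_batch_size_index off 0 yet (the generation
loop builds a single-element list), but the list-based caches make it possible
to interleave several micro-batches through one model while keeping their
key-value state separate. A hedged sketch of the intended pattern
(micro_batches and num_steps are illustrative names, not from the commit):

    inference_params = InferenceParams(micro_batch_size_list=[8, 8],
                                       max_sequence_len=512)

    first_pass = True
    for step in range(num_steps):
        for index, batch in enumerate(micro_batches):
            # Route this micro-batch to its own key-value cache. The first
            # forward pass allocates caches for every size in the list at
            # once, so allocation must be requested exactly once.
            inference_params.micro_batch_size_index = index
            inference_params.allocate_key_value_memory = first_pass
            first_pass = False
            logits = forward_step(model, batch.tokens, batch.position_ids,
                                  batch.attention_mask, inference_params)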