Commit f5345dfa authored by Jared Casper

Merge branch 'inference_context_memory' into 'main'

made model stateless with respect to inference

See merge request ADLR/megatron-lm!348
parents d33460df efc750b6
+16 −23
@@ -179,10 +179,6 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-        # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(
@@ -203,19 +199,18 @@
         # Pre-allocate memory for key-values for inference.
         # =================================================
         if inference_params:
-            if inference_params.allocate_key_value_memory:
+            if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
                 inf_max_batch_size = inference_params.max_batch_size
-                self.inference_key_memory = self._allocate_memory(
+                inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-                self.inference_value_memory = self._allocate_memory(
+                inference_value_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-        # This is added for safety. In case inference_params
-        # is not provided, make sure there is no potential memory left
-        # from previous inference.
-        else:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
 
 
         # =====================
@@ -266,20 +261,18 @@
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= self.inference_key_memory.size(1)
+            assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= self.inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory[sequence_start:sequence_end,
-                                      batch_start:batch_end,
-                                      ...] = key_layer
-            self.inference_value_memory[sequence_start:sequence_end,
-                                        batch_start:batch_end,
-                                        ...] = value_layer
-            key_layer = self.inference_key_memory[
+            inference_key_memory[sequence_start:sequence_end,
+                                 batch_start:batch_end, ...] = key_layer
+            inference_value_memory[sequence_start:sequence_end,
+                                   batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[
                 :sequence_end, batch_start:batch_end, ...]
-            value_layer = self.inference_value_memory[
+            value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]
 
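For readers skimming the diff, the pattern the hunks above introduce can be summarized in a small, self-contained sketch: each attention layer allocates its key/value buffers on its first forward pass, parks them in inference_params.key_value_memory_dict under its layer_number, and on later steps looks them up, writes the new keys/values at the current offsets, and attends over the cached prefix. The SimpleInferenceParams and toy_attention_step names and the simplified tensor shapes below are illustrative only; the field and dict names are the ones used in the diff.

import torch


class SimpleInferenceParams:
    """Stand-in for Megatron's InferenceParams: all inference state lives here."""

    def __init__(self, max_batch_size, max_sequence_len):
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}


def toy_attention_step(layer_number, key_layer, value_layer,
                       inference_params, hidden_size):
    """One layer's view of the cache: allocate on first use, then append
    this step's keys/values and return the cached prefix."""
    if layer_number not in inference_params.key_value_memory_dict:
        # First forward pass for this layer: allocate [max_seq, max_batch, hidden].
        shape = (inference_params.max_sequence_len,
                 inference_params.max_batch_size, hidden_size)
        inference_params.key_value_memory_dict[layer_number] = (
            torch.empty(shape), torch.empty(shape))
    key_memory, value_memory = \
        inference_params.key_value_memory_dict[layer_number]

    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + key_layer.size(0)
    key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
    value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
    # Attention then runs against everything cached so far.
    return (key_memory[:sequence_end, batch_start:batch_end, ...],
            value_memory[:sequence_end, batch_start:batch_end, ...])


# Example: cache a 3-token prompt chunk for layer 1, then advance the offset.
params = SimpleInferenceParams(max_batch_size=2, max_sequence_len=8)
k = torch.randn(3, 2, 16)   # [sequence, batch, hidden]
v = torch.randn(3, 2, 16)
k_full, v_full = toy_attention_step(1, k, v, params, hidden_size=16)
params.sequence_len_offset += 3

Because the buffers hang off inference_params rather than self, nothing from a previous generation can linger on the module, which is what removes the need for the defensive reset in the old code.
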
+1 −6
@@ -40,7 +40,7 @@ class InferenceParams:
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
         self.batch_size_offset = 0
-        self.allocate_key_value_memory = True
+        self.key_value_memory_dict = {}
 
 
@@ -132,11 +132,6 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
     # Send output to the next stage.
     send_to_next_pipeline_rank(output_tensor)
 
-    # Make sure we do not allocate context memory anymore.
-    if inference_params.allocate_key_value_memory:
-        inference_params.allocate_key_value_memory = False
-
-
     return output_tensor
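As context for the removed flag handling, here is a rough sketch of the driver side after this change. The dummy_forward_step helper, the two-layer loop, and the prompt/decode shapes are made up for illustration and are not the real Megatron-LM generation loop; only the InferenceParams fields mirror the diff. The point is that first-use allocation is now driven by dict membership, so there is no allocate_key_value_memory flag to flip after the first step, and discarding the params object discards every layer's cache.

import types

import torch


def dummy_forward_step(step_len, params):
    # Stand-in for one forward pass through a two-layer model: each layer
    # caches into the shared params object, never onto the module itself.
    for layer_number in (1, 2):
        if layer_number not in params.key_value_memory_dict:
            params.key_value_memory_dict[layer_number] = torch.zeros(
                params.max_sequence_len, params.max_batch_size)
        cache = params.key_value_memory_dict[layer_number]
        start = params.sequence_len_offset
        cache[start:start + step_len] = 1.0


def fresh_params():
    # Same fields as the InferenceParams shown in the diff, built ad hoc.
    return types.SimpleNamespace(
        max_sequence_len=16, max_batch_size=2,
        sequence_len_offset=0, batch_size_offset=0,
        key_value_memory_dict={})


# One request: a 4-token prompt followed by three single-token decode steps.
params = fresh_params()
dummy_forward_step(4, params)
params.sequence_len_offset += 4
for _ in range(3):
    dummy_forward_step(1, params)
    params.sequence_len_offset += 1

# The next request starts from a new params object; the model kept no buffers,
# so there is no stale memory to reset and no allocation flag to manage.
params = fresh_params()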