megatron/model/transformer.py  (+16 −23)

@@ -179,10 +179,6 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-        # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(

@@ -203,19 +199,18 @@ class ParallelAttention(MegatronModule):
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
         if inference_params:
-            if inference_params.allocate_key_value_memory:
+            if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
                 inf_max_batch_size = inference_params.max_batch_size
-                self.inference_key_memory = self._allocate_memory(
+                inference_key_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-                self.inference_value_memory = self._allocate_memory(
+                inference_value_memory = self._allocate_memory(
                     inf_max_seq_len, inf_max_batch_size)
-        # This is added for safety. In case inference_params
-        # is not provided, make sure there is no potential memory left
-        # from previous inference.
-        else:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
 
         # =====================

@@ -266,20 +261,18 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
-            assert batch_end <= self.inference_key_memory.size(1)
+            assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key_layer.size(0)
-            assert sequence_end <= self.inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory[sequence_start:sequence_end,
-                                      batch_start:batch_end, ...] = key_layer
-            self.inference_value_memory[sequence_start:sequence_end,
-                                        batch_start:batch_end, ...] = value_layer
-            key_layer = self.inference_key_memory[
+            inference_key_memory[sequence_start:sequence_end,
+                                 batch_start:batch_end, ...] = key_layer
+            inference_value_memory[sequence_start:sequence_end,
+                                   batch_start:batch_end, ...] = value_layer
+            key_layer = inference_key_memory[
                 :sequence_end, batch_start:batch_end, ...]
-            value_layer = self.inference_value_memory[
+            value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]
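The net effect of the transformer.py change is that the per-layer KV cache moves off the attention module and into inference_params.key_value_memory_dict, keyed by self.layer_number, so a fresh InferenceParams object automatically starts with fresh caches and nothing stale survives on the model between generation requests. A minimal standalone sketch of that pattern (the buffer shape, dtype handling, and the cached_kv helper name are illustrative assumptions; the real allocation lives in the truncated _allocate_memory / torch.empty(...) call):

    import torch

    class InferenceParams:
        """Holds per-layer KV caches so the attention modules stay stateless."""
        def __init__(self, max_batch_size, max_sequence_len):
            self.max_sequence_len = max_sequence_len
            self.max_batch_size = max_batch_size
            self.sequence_len_offset = 0
            self.batch_size_offset = 0
            self.key_value_memory_dict = {}

    def cached_kv(inference_params, layer_number, key_layer, value_layer):
        # Allocate this layer's cache once, on the first forward pass.
        if layer_number not in inference_params.key_value_memory_dict:
            buf_shape = (inference_params.max_sequence_len,
                         inference_params.max_batch_size,
                         *key_layer.shape[2:])  # assumed [seq, batch, ...] layout
            inference_params.key_value_memory_dict[layer_number] = (
                torch.empty(buf_shape, dtype=key_layer.dtype, device=key_layer.device),
                torch.empty(buf_shape, dtype=value_layer.dtype, device=value_layer.device))
        key_mem, value_mem = inference_params.key_value_memory_dict[layer_number]

        # Write the new tokens' keys/values at the current offsets, then return
        # views covering everything computed so far for this batch slice.
        s0 = inference_params.sequence_len_offset
        s1 = s0 + key_layer.size(0)
        b0 = inference_params.batch_size_offset
        b1 = b0 + key_layer.size(1)
        key_mem[s0:s1, b0:b1, ...] = key_layer
        value_mem[s0:s1, b0:b1, ...] = value_layer
        return key_mem[:s1, b0:b1, ...], value_mem[:s1, b0:b1, ...]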
megatron/text_generation/forward_step.py  (+1 −6)

@@ -40,7 +40,7 @@ class InferenceParams:
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
         self.batch_size_offset = 0
-        self.allocate_key_value_memory = True
+        self.key_value_memory_dict = {}

@@ -132,11 +132,6 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
     # Send output to the next stage.
     send_to_next_pipeline_rank(output_tensor)
 
-    # Make sure we do not allocate context memory anymore.
-    if inference_params.allocate_key_value_memory:
-        inference_params.allocate_key_value_memory = False
-
     return output_tensor
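The removed allocate_key_value_memory flag existed only to stop re-allocation after the first pass; with the dict keyed by layer number, "already allocated" is expressed by dict membership instead. What the caller is still responsible for is advancing sequence_len_offset between decode steps, roughly as in this hypothetical driver sketch (the model call signature, greedy sampling, and elided position-id/mask updates are illustrative, not the actual Megatron text-generation API):

    def greedy_decode(model, tokens, position_ids, attention_mask, max_new_tokens):
        # One InferenceParams per request: its key_value_memory_dict starts empty,
        # so every layer allocates a fresh cache on the first forward pass.
        params = InferenceParams(max_batch_size=tokens.size(0),
                                 max_sequence_len=2048)
        for _ in range(max_new_tokens):
            logits = model(tokens, position_ids, attention_mask,
                           inference_params=params)  # prompt first, then 1 token/step
            # Advance the write offset so the next step appends after the cached tokens.
            params.sequence_len_offset += tokens.size(1)
            tokens = logits[:, -1].argmax(dim=-1, keepdim=True)
        return tokens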