megatron/inference/forward_step.py  +126 −122

@@ -15,8 +15,6 @@

 """Forward step utilities."""

-from abc import ABC
-from abc import abstractmethod
 from collections.abc import Iterable

 import torch

@@ -24,44 +22,27 @@ import torch

 from megatron import (
     get_args,
     mpu)
-from megatron.p2p_communication import (
-    recv_forward,
-    send_forward)
-
-
-def forward_step_provider(model, batch_size, micro_batch_size,
-                          max_sequence_len):
-    args = get_args()
-    if args.pipeline_model_parallel_size == 1 or \
-       micro_batch_size >= batch_size:
-        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
-    return SimplePipeliningForwardStep(model, batch_size, micro_batch_size,
-                                       max_sequence_len)


 class InferenceParams:

-    def __init__(self, micro_batch_size_list, max_sequence_len):
-        assert isinstance(micro_batch_size_list, list)
-        assert max_sequence_len > 0
-        self.micro_batch_size_list = micro_batch_size_list
+    def __init__(self, max_batch_size, max_sequence_len):
         self.max_sequence_len = max_sequence_len
+        self.max_batch_size = max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
         self.allocate_key_value_memory = True
-        self.micro_batch_index = 0


-class ForwardStepBase(ABC):
+class ForwardStep:

-    def __init__(self, model):
+    def __init__(self, model, max_batch_size, max_sequence_len):
         # Make sure model is in eval mode.
         if isinstance(model, Iterable):
             for this_model in model:
                 this_model.eval()

@@ -69,125 +50,148 @@ class ForwardStepBase(ABC):
             model.eval()
         self.model = model

-    @abstractmethod
-    def __call__(self, tokens, position_ids, attention_mask):
-        pass
+        self.constant = 512
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(max_batch_size,
+                                                max_sequence_len)
+
+    def __call__(self, tokens, position_ids, attention_mask):
+        if tokens.size(0) * tokens.size(1) >= self.constant:
+            micro_batch_size = max(1, self.constant // tokens.size(1))
+            return _with_pipelining_forward_step(self.model, tokens,
+                                                 position_ids, attention_mask,
+                                                 self.inference_params,
+                                                 micro_batch_size)
+        else:
+            return _no_pipelining_forward_step(self.model, tokens,
+                                               position_ids, attention_mask,
+                                               self.inference_params)

-class SimplePipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
-        super().__init__(model)
-        self.batch_size = batch_size
-        # Divide the batch dimension into micro batches.
-        self.num_micro_batches, last_chunk = divmod(batch_size,
-                                                    micro_batch_size)
-        self.micro_batch_size_list = []
-        self.batch_dim_start_index = [0]
-        for i in range(self.num_micro_batches):
-            self.micro_batch_size_list.append(micro_batch_size)
-            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
-        if last_chunk > 0:
-            self.num_micro_batches += 1
-            self.micro_batch_size_list.append(last_chunk)
-            self.batch_dim_start_index.append(batch_size)
-        self.inference_params = InferenceParams(self.micro_batch_size_list,
-                                                max_sequence_len)

+def _get_recv_buffer_dtype(args):
+    """Receive happens between the layers."""
+    if args.fp32_residual_connection:
+        return torch.float
+    return args.params_dtype
+
+
+def _allocate_recv_buffer(batch_size, sequence_length):
+    """Receive happens between the layers with size [s, b, h]."""
+    if mpu.is_pipeline_first_stage():
+        return None
+    args = get_args()
+    recv_size = (sequence_length, batch_size, args.hidden_size)
+    return torch.empty(recv_size,
+                       dtype=_get_recv_buffer_dtype(args),
+                       device=torch.cuda.current_device())

-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.size(1)
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        # Preallocate memory for output logits.
-        logits = None
-        if mpu.is_pipeline_last_stage():
-            logits = torch.empty(tokens.size(0), tokens.size(1),
-                                 args.padded_vocab_size,
-                                 dtype=torch.float32,
-                                 device=torch.cuda.current_device())
-        # Pileline using micro batches.
-        for micro_batch_index in range(self.num_micro_batches):
-            # Set micro-batch size and index.
-            self.inference_params.micro_batch_index = micro_batch_index
-            args.micro_batch_size = self.micro_batch_size_list[
-                micro_batch_index]
-            # Slice among the batch dimenion.
-            start = self.batch_dim_start_index[micro_batch_index]
-            end = self.batch_dim_start_index[micro_batch_index + 1]
-            tokens2use = tokens[start:end, ...]
-            position_ids2use = position_ids[start:end, ...]
-            # Receive from previous stage.
-            input_tensor = recv_forward()
-            # Forward pass through the model.
-            self.model.set_input_tensor(input_tensor)
-            output_tensor = self.model(tokens2use, position_ids2use,
-                                       attention_mask,
-                                       inference_params=self.inference_params)
-            # Send output to the next stage.
-            send_forward(output_tensor)
-            # Make sure we do not allocate context memory anymore.
-            if self.inference_params.allocate_key_value_memory:
-                self.inference_params.allocate_key_value_memory = False
-            if mpu.is_pipeline_last_stage():
-                logits[start:end, ...] = output_tensor
-        # Adjust the sequence length back to whatever it was before.
-        args.seq_length = orig_seq_length
-        return logits

+def _forward_step_helper(model, tokens, position_ids, attention_mask,
+                         inference_params, recv_buffer=None):
+    """Single forward step. Update the allocate memory flag so
+    only the first time the memory is allocated."""
+    batch_size = tokens.size(0)
+    sequence_length = tokens.size(1)
+    if recv_buffer is None:
+        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
+
+    # Receive from previous stage.
+    if not mpu.is_pipeline_first_stage():
+        torch.distributed.recv(recv_buffer,
+                               src=mpu.get_pipeline_model_parallel_prev_rank())
+
+    # Forward pass through the model.
+    model.set_input_tensor(recv_buffer)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
+
+    # Send output to the next stage.
+    if not mpu.is_pipeline_last_stage():
+        torch.distributed.send(output_tensor,
+                               mpu.get_pipeline_model_parallel_next_rank())
+
+    # Make sure we do not allocate context memory anymore.
+    if inference_params.allocate_key_value_memory:
+        inference_params.allocate_key_value_memory = False
+
+    return output_tensor
+
+
+def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                inference_params, recv_buffer=None):
+    # Run a simple forward pass.
+    output_tensor = _forward_step_helper(model, tokens, position_ids,
+                                         attention_mask, inference_params,
+                                         recv_buffer=recv_buffer)
+    # Update the sequence length offset.
+    inference_params.sequence_len_offset += tokens.size(1)
+
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        logits = output_tensor
+
+    return logits

-class NoPipeliningForwardStep(ForwardStepBase):
-
-    def __init__(self, model, batch_size, max_sequence_len):
-        super().__init__(model)
-        self.inference_params = InferenceParams([batch_size],
-                                                max_sequence_len)
-
-    def __call__(self, tokens, position_ids, attention_mask):
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.shape[1]
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        args.micro_batch_size = tokens.shape[0]
-        assert self.inference_params.micro_batch_size_list[0] == \
-            tokens.shape[0]
-        assert self.inference_params.micro_batch_index == 0
-        # Receive from previous stage.
-        input_tensor = recv_forward()
-        # Forward pass through the model.
-        self.model.set_input_tensor(input_tensor)
-        output_tensor = self.model(tokens, position_ids, attention_mask,
-                                   inference_params=self.inference_params)
-        # Send output to the next stage.
-        send_forward(output_tensor)
-        # Reset the sequence lenght to whatwever it was before.
-        args.seq_length = orig_seq_length
-        # Make sure we do not allocate context memory anymore.
-        if self.inference_params.allocate_key_value_memory:
-            self.inference_params.allocate_key_value_memory = False
-        return output_tensor

+def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                  inference_params, micro_batch_size):
+    sequence_length = tokens.size(1)
+    batch_size = tokens.size(0)
+
+    # Divide the batch dimension into micro batches.
+    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
+    if last_chunk > 0:
+        num_micro_batches += 1
+
+    # Preallocate memory for output logits.
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        args = get_args()
+        logits = torch.empty(
+            (batch_size, sequence_length, args.padded_vocab_size),
+            dtype=torch.float32,
+            device=torch.cuda.current_device())
+
+    # Preallocate recv buffer.
+    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
+
+    for micro_batch_index in range(num_micro_batches):
+        # Slice among the batch dimenion.
+        start = micro_batch_index * micro_batch_size
+        end = min(start + micro_batch_size, batch_size)
+        this_micro_batch_size = end - start
+        tokens2use = tokens[start:end, ...]
+        position_ids2use = position_ids[start:end, ...]
+
+        # Run a simple forward pass.
+        if this_micro_batch_size != micro_batch_size:
+            recv_buffer = None
+        output = _forward_step_helper(model, tokens2use, position_ids2use,
+                                      attention_mask, inference_params,
+                                      recv_buffer=recv_buffer)
+
+        # Adjust the batch size offset to account for the micro-batch.
+        inference_params.batch_size_offset += this_micro_batch_size
+
+        # Copy logits.
+        if mpu.is_pipeline_last_stage():
+            logits[start:end, ...] = output
+
+    # Once we are done with all the micro-batches, we can
+    # adjust the sequence length offset.
+    inference_params.sequence_len_offset += sequence_length
+    # and reset the batch size offset
+    inference_params.batch_size_offset = 0
+
+    return logits


megatron/inference/generation.py  +4 −5

@@ -24,7 +24,7 @@
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step_provider
+from .forward_step import ForwardStep
 from .sampling import sample

@@ -66,8 +66,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

     # forward step.
-    forward_step = forward_step_provider(model, batch_size, 4,
-                                         max_sequence_length)
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)

     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.

@@ -190,8 +189,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            if done:
-                break
+            #if done:
+            #    break

     # ===================================================
     # Update the length of based on max generated length.
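For readers skimming the diff, here is a minimal, self-contained sketch of the calling pattern the new ForwardStep/InferenceParams pair is built around: allocate one key-value cache sized [max_sequence_len, max_batch_size, hidden] on the first call, then advance sequence_len_offset as tokens are generated so each later step only feeds the new tokens. This is plain PyTorch with no Megatron imports; ToyAttention, the tensor sizes, and the dummy attention math are invented for illustration and are not part of the PR.

```python
import torch


class InferenceParams:
    """Mirrors the fields the PR introduces: offsets instead of a micro-batch list."""

    def __init__(self, max_batch_size, max_sequence_len):
        self.max_batch_size = max_batch_size
        self.max_sequence_len = max_sequence_len
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.allocate_key_value_memory = True


class ToyAttention(torch.nn.Module):
    """Toy stand-in for ParallelAttention's KV caching; activations are [s, b, h]."""

    def __init__(self, hidden):
        super().__init__()
        self.hidden = hidden
        self.key_memory = None
        self.value_memory = None

    def forward(self, hidden_states, inference_params):
        s, b, _ = hidden_states.shape
        if inference_params.allocate_key_value_memory:
            # One big buffer per layer, allocated once for the whole generation.
            size = (inference_params.max_sequence_len,
                    inference_params.max_batch_size,
                    self.hidden)
            self.key_memory = torch.zeros(size)
            self.value_memory = torch.zeros(size)
            # The real code flips this flag in _forward_step_helper.
            inference_params.allocate_key_value_memory = False
        seq_start = inference_params.sequence_len_offset
        seq_end = seq_start + s
        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + b
        # Write the new keys/values into the cache ...
        self.key_memory[seq_start:seq_end, batch_start:batch_end] = hidden_states
        self.value_memory[seq_start:seq_end, batch_start:batch_end] = hidden_states
        # ... and attend over everything cached so far (dummy reduction here).
        cached = self.key_memory[:seq_end, batch_start:batch_end]
        return cached.mean(dim=0, keepdim=True).expand(s, -1, -1)


params = InferenceParams(max_batch_size=2, max_sequence_len=16)
layer = ToyAttention(hidden=8)

# Prompt pass: 5 tokens at once, then 3 single-token decode steps.
prompt = torch.randn(5, 2, 8)
_ = layer(prompt, params)
params.sequence_len_offset += prompt.size(0)

for _ in range(3):
    token = torch.randn(1, 2, 8)
    _ = layer(token, params)
    params.sequence_len_offset += 1
```

The point of the change shows up in the last loop: after the prompt pass, every decode step writes into the next row of the same preallocated buffers instead of rebuilding per-micro-batch cache lists.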
megatron/model/transformer.py  +25 −31

@@ -180,9 +180,8 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)

         # Inference key-value memory
-        self.inference_key_memory_list = None
-        self.inference_value_memory_list = None
-        self.inference_current_sequence_len_list = None
+        self.inference_key_memory = None
+        self.inference_value_memory = None

     def _allocate_memory(self, inference_max_sequence_len, batch_size):

@@ -206,22 +205,17 @@ class ParallelAttention(MegatronModule):
         if inference_params:
             if inference_params.allocate_key_value_memory:
                 inf_max_seq_len = inference_params.max_sequence_len
-                inf_batch_sizes = inference_params.micro_batch_size_list
-                self.inference_key_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_value_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_current_sequence_len_list = [
-                    0 for _ in inf_batch_sizes]
+                inf_max_batch_size = inference_params.max_batch_size
+                self.inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                self.inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
         # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
         else:
-            self.inference_key_memory_list = None
-            self.inference_value_memory_list = None
-            self.inference_current_sequence_len_list = None
+            self.inference_value_memory = None
+            self.inference_current_sequence_len = None

         # =====================
         # Query, Key, and Value

@@ -269,23 +263,23 @@ class ParallelAttention(MegatronModule):
         # ==================================
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_index
-            assert key_layer.size(1) == \
-                inference_params.micro_batch_size_list[inf_batch_index]
-            # Adjust the range variables.
-            start = self.inference_current_sequence_len_list[inf_batch_index]
-            end = start + key_layer.size(0)
-            assert end <= inference_params.max_sequence_len
-            self.inference_current_sequence_len_list[inf_batch_index] = end
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= self.inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= self.inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
-                = key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
-                = value_layer
-            key_layer = \
-                self.inference_key_memory_list[inf_batch_index][:end, ...]
-            value_layer = \
-                self.inference_value_memory_list[inf_batch_index][:end, ...]
+            self.inference_key_memory[sequence_start:sequence_end,
+                                      batch_start:batch_end, ...] = key_layer
+            self.inference_value_memory[sequence_start:sequence_end,
+                                        batch_start:batch_end, ...] = value_layer
+            key_layer = self.inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = self.inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]

         # ===================================
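To make the batch arithmetic concrete, here is a small standalone sketch (hypothetical helper names, no Megatron dependencies) of how `_with_pipelining_forward_step` splits a batch with divmod and how `batch_size_offset` walks across the shared KV cache before being reset; the assert mirrors the bounds check ParallelAttention now performs against `inference_key_memory.size(1)`.

```python
def split_into_micro_batches(batch_size, micro_batch_size):
    """Same divmod logic as the PR: a short last chunk becomes its own micro-batch."""
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    ranges = []
    for index in range(num_micro_batches):
        start = index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        ranges.append((start, end))
    return ranges


class Offsets:
    """Stand-in for the two offsets InferenceParams now carries."""

    def __init__(self):
        self.batch_size_offset = 0
        self.sequence_len_offset = 0


offsets = Offsets()
sequence_length = 7

for start, end in split_into_micro_batches(batch_size=10, micro_batch_size=4):
    # Each micro-batch writes its keys/values into columns [start:end) of the
    # shared cache; the attention layer reads batch_size_offset to know where.
    assert offsets.batch_size_offset == start
    offsets.batch_size_offset += end - start

# Once all micro-batches are done: advance the sequence offset by the full
# sequence length and reset the batch offset, exactly as the new code does.
offsets.sequence_len_offset += sequence_length
offsets.batch_size_offset = 0

print(split_into_micro_batches(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

With batch_size=10 and micro_batch_size=4 the ranges come out as [(0, 4), (4, 8), (8, 10)]: the short last chunk simply becomes its own micro-batch, which is also why the new code drops the preallocated recv_buffer for that final, smaller chunk.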