megatron/inference/forward_step.py  (+16 −7)

@@ -22,9 +22,20 @@
 from megatron.p2p_communication import recv_forward, send_forward
 from megatron import get_args


-def forward_step(model, tokens, position_ids, attention_mask,
-                 set_inference_key_value_memory=False,
-                 inference_max_sequence_len=None):
+class InferenceParams:
+
+    def __init__(self, micro_batch_size_list, max_sequence_len):
+        assert isinstance(micro_batch_size_list, list)
+        assert max_sequence_len > 0
+        self.micro_batch_size_list = micro_batch_size_list
+        self.max_sequence_len = max_sequence_len
+        self.allocate_key_value_memory = False
+        self.micro_batch_size_index = 0
+
+
+def forward_step(model, tokens, position_ids, attention_mask,
+                 inference_params):

     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size

@@ -37,10 +48,8 @@
     # Forward pass through the model.
     model.set_input_tensor(input_tensor)
-    output_tensor = model(
-        tokens, position_ids, attention_mask,
-        set_inference_key_value_memory=set_inference_key_value_memory,
-        inference_max_sequence_len=inference_max_sequence_len)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)

     send_forward(output_tensor)


megatron/inference/generation.py  (+8 −6)

@@ -25,7 +25,7 @@
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step
+from .forward_step import forward_step, InferenceParams
 from .sampling import sample

@@ -109,6 +109,9 @@
     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)

+    # Set inference params
+    inference_params = InferenceParams([batch_size], max_sequence_length)
+
     model.eval()
     with torch.no_grad():
         prev_context_length = 0

@@ -117,7 +120,8 @@
             # If we are starting from scratch, allocate memory for the entire
             # context, otherwise set this to false so the memory is not
             # reallocated.
-            set_inference_key_value_memory = (prev_context_length == 0)
+            inference_params.allocate_key_value_memory = \
+                (prev_context_length == 0)

             # Pick the slice that we need to pass through the network.
             tokens2use = tokens[:, prev_context_length:context_length]

@@ -126,10 +130,8 @@
                 ..., prev_context_length:context_length, :context_length]

             # logits will be meanigful only in the last pipeline stage.
-            logits = forward_step(
-                model, tokens2use, positions2use, attention_mask2use,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=max_sequence_length)
+            logits = forward_step(model, tokens2use, positions2use,
+                                  attention_mask2use, inference_params)

             if mpu.is_pipeline_last_stage():
                 # Always the last stage should have an output.
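Taken together, the two files above define the new flow: generation.py builds a single InferenceParams object up front, toggles its allocate_key_value_memory flag once per loop iteration, and forward_step simply forwards that object to the model. The sketch below (not part of the change) restates that driver loop in isolation; it assumes Megatron is already initialized and that model, tokens, position_ids, attention_mask, and prompt_length (a hypothetical name for the initial context length) exist, as they do inside generate_tokens_probs_and_return_on_first_stage.

# Minimal sketch of driving incremental decoding with InferenceParams,
# mirroring the loop in generation.py. `model`, `tokens`, `position_ids`,
# `attention_mask`, and `prompt_length` are assumed to be built already.
import torch

from megatron.inference.forward_step import forward_step, InferenceParams

batch_size = tokens.size(0)
max_sequence_length = tokens.size(1)

# One params object for the whole generation; a single micro batch here.
inference_params = InferenceParams([batch_size], max_sequence_length)

model.eval()
with torch.no_grad():
    prev_context_length = 0
    for context_length in range(prompt_length, max_sequence_length):
        # Allocate the key/value cache only on the very first chunk.
        inference_params.allocate_key_value_memory = (prev_context_length == 0)

        # Only the not-yet-processed slice is pushed through the network.
        tokens2use = tokens[:, prev_context_length:context_length]
        positions2use = position_ids[:, prev_context_length:context_length]
        attention_mask2use = attention_mask[
            ..., prev_context_length:context_length, :context_length]

        # Meaningful only on the last pipeline stage; earlier stages get None.
        logits = forward_step(model, tokens2use, positions2use,
                              attention_mask2use, inference_params)
        prev_context_length = context_length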
megatron/model/gpt_model.py  (+2 −5)

@@ -82,16 +82,13 @@
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            set_inference_key_value_memory=set_inference_key_value_memory,
-            inference_max_sequence_len=inference_max_sequence_len)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(


megatron/model/language_model.py  (+3 −6)

@@ -335,8 +335,7 @@
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None,
+                inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

@@ -353,8 +352,7 @@
             encoder_output = self.encoder(
                 encoder_input,
                 enc_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -381,8 +379,7 @@
                 dec_attn_mask,
                 encoder_output=encoder_output,
                 enc_dec_attn_mask=enc_dec_attn_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output


megatron/model/transformer.py  (+47 −48)

@@ -180,9 +180,9 @@
             skip_bias_add=True)

         # Inference key-value memory
-        self.inference_key_memory = None
-        self.inference_value_memory = None
-        self.inference_current_sequence_len = 0
+        self.inference_key_memory_list = None
+        self.inference_value_memory_list = None
+        self.inference_current_sequence_len_list = None

     def _allocate_memory(self, inference_max_sequence_len, batch_size):

@@ -196,35 +196,32 @@
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        if set_inference_key_value_memory:
-            assert inference_max_sequence_len and inference_max_sequence_len > 0
-            self.inference_key_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_value_memory = self._allocate_memory(
-                inference_max_sequence_len, hidden_states.size(1))
-            self.inference_current_sequence_len = 0
-        # Some consistency check.
-        if inference_max_sequence_len:
-            assert self.inference_current_sequence_len < \
-                self.inference_key_memory.size(0)
-            assert inference_max_sequence_len == \
-                self.inference_key_memory.size(0)
-        # This is added for safety. In case inference_max_sequence_len
+        if inference_params:
+            if inference_params.allocate_key_value_memory:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_batch_sizes = inference_params.micro_batch_size_list
+                self.inference_key_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_value_memory_list = [
+                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
+                    for inf_batch_size in inf_batch_sizes]
+                self.inference_current_sequence_len_list = [
+                    0 for _ in inf_batch_sizes]
+        # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
-        if not inference_max_sequence_len:
-            self.inference_key_memory = None
-            self.inference_value_memory = None
+        else:
+            self.inference_key_memory_list = None
+            self.inference_value_memory_list = None
+            self.inference_current_sequence_len_list = None

         # =====================
         # Query, Key, and Value

@@ -267,20 +264,27 @@
         query_layer = query_layer.view(*new_tensor_shape)

-        # ===================================================
-        # Adjust key, value, and attention mask for inference
-        # ===================================================
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================

-        if inference_max_sequence_len:
+        if inference_params:
+            inf_batch_index = inference_params.micro_batch_size_index
+            assert key_layer.size(1) == \
+                inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
-            start = self.inference_current_sequence_len
-            self.inference_current_sequence_len += key_layer.size(0)
-            end = self.inference_current_sequence_len
+            start = self.inference_current_sequence_len_list[inf_batch_index]
+            end = start + key_layer.size(0)
+            self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy key and values.
-            self.inference_key_memory[start:end, ...] = key_layer
-            self.inference_value_memory[start:end, ...] = value_layer
-            key_layer = self.inference_key_memory[:end, ...]
-            value_layer = self.inference_value_memory[:end, ...]
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
+                key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
+                value_layer
+            key_layer = \
+                self.inference_key_memory_list[inf_batch_index][:end, ...]
+            value_layer = \
+                self.inference_value_memory_list[inf_batch_index][:end, ...]

         # ===================================

@@ -459,10 +463,8 @@
             output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.

@@ -472,8 +474,7 @@
             self.self_attention(
                 layernorm_output,
                 attention_mask,
-                set_inference_key_value_memory=set_inference_key_value_memory,
-                inference_max_sequence_len=inference_max_sequence_len)
+                inference_params=inference_params)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:

@@ -686,13 +687,11 @@
         self.input_tensor = input_tensor

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                set_inference_key_value_memory=False,
-                inference_max_sequence_len=None):
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):

         # Checks.
-        if inference_max_sequence_len:
+        if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'

@@ -724,8 +723,8 @@
                     attention_mask,
                     encoder_output=encoder_output,
                     enc_dec_attn_mask=enc_dec_attn_mask,
-                    set_inference_key_value_memory=set_inference_key_value_memory,
-                    inference_max_sequence_len=inference_max_sequence_len)
+                    inference_params=inference_params)

         # Final layer norm.
         if self.post_process:
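The substantive change in transformer.py is that each ParallelAttention layer now keeps one pre-allocated key/value buffer per micro batch and picks the active buffer with inference_params.micro_batch_size_index, instead of holding a single inference_key_memory/inference_value_memory pair. The standalone sketch below re-implements just that per-micro-batch cache pattern outside Megatron; the class name, kv_dim argument, and plain CPU tensors are illustrative assumptions, not the actual ParallelAttention code.

# Self-contained sketch of the per-micro-batch KV cache pattern used above.
# Shapes follow Megatron's [sequence, batch, ...] layout; names are invented.
import torch


class SimpleKVCache:
    """One key/value buffer and one sequence offset per micro batch,
    analogous to inference_key_memory_list / inference_value_memory_list."""

    def __init__(self, micro_batch_sizes, max_sequence_len, kv_dim):
        self.key_memory_list = [
            torch.empty(max_sequence_len, b, kv_dim) for b in micro_batch_sizes]
        self.value_memory_list = [
            torch.empty(max_sequence_len, b, kv_dim) for b in micro_batch_sizes]
        self.current_sequence_len_list = [0 for _ in micro_batch_sizes]

    def update(self, batch_index, key_layer, value_layer):
        # Advance the offset for this micro batch, copy in the new chunk,
        # and return views over everything cached so far.
        start = self.current_sequence_len_list[batch_index]
        end = start + key_layer.size(0)
        self.current_sequence_len_list[batch_index] = end
        self.key_memory_list[batch_index][start:end, ...] = key_layer
        self.value_memory_list[batch_index][start:end, ...] = value_layer
        return (self.key_memory_list[batch_index][:end, ...],
                self.value_memory_list[batch_index][:end, ...])


# Example: an 8-token prompt chunk followed by one generated token, batch 2.
cache = SimpleKVCache(micro_batch_sizes=[2], max_sequence_len=16, kv_dim=4)
k, v = cache.update(0, torch.randn(8, 2, 4), torch.randn(8, 2, 4))  # 8 cached
k, v = cache.update(0, torch.randn(1, 2, 4), torch.randn(1, 2, 4))  # 9 cached
print(k.shape, v.shape)  # torch.Size([9, 2, 4]) torch.Size([9, 2, 4])

Tracking a separate offset per micro batch is what allows several independent generation streams to reuse the same attention layers without overwriting one another's cached keys and values.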