megatron/inference/forward_step.py  (+66 −8)

@@ -15,13 +15,18 @@
 """Forward step utilities."""

 import torch
+from collections.abc import Iterable
+from enum import Enum

 from megatron.p2p_communication import recv_forward, send_forward
 from megatron import get_args


+class ForwardStepTypes(Enum):
+    NO_PIPELINING = 1
+

 class InferenceParams:

     def __init__(self, micro_batch_size_list, max_sequence_len):
@@ -31,10 +36,67 @@ class InferenceParams:
         self.micro_batch_size_list = micro_batch_size_list
         self.max_sequence_len = max_sequence_len
-        self.allocate_key_value_memory = False
+        self.allocate_key_value_memory = True
+        self.micro_batch_size_index = 0
+
+
+class InferenceForwardStep:
+
+    def __init__(self, model, batch_size, max_sequence_len):
+        if isinstance(model, Iterable):
+            for this_model in model:
+                this_model.eval()
+        else:
+            model.eval()
+        self.model = model
+        self.inference_params = InferenceParams([batch_size], max_sequence_len)
+        self.forward_step_type = ForwardStepTypes.NO_PIPELINING
+
+    def __call__(self, tokens, position_ids, attention_mask):
+        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
+            return self._forward_step_no_pipelining(tokens, position_ids,
+                                                    attention_mask)
+        raise Exception('unknown forward step type {}'.format(
+            self.forward_step_type))
+
+    def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):
+        # Need to tell p2p_communicate functions the correct size.
+        args = get_args()
+        orig_seq_length = args.seq_length
+        args.seq_length = tokens.shape[1]
+        assert args.seq_length <= self.inference_params.max_sequence_len
+        args.micro_batch_size = tokens.shape[0]
+        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
+        assert self.inference_params.micro_batch_size_index == 0
+
+        # Receive from previous stage.
+        input_tensor = recv_forward()
+
+        # Forward pass through the model.
+        self.model.set_input_tensor(input_tensor)
+        output_tensor = self.model(tokens, position_ids, attention_mask,
+                                   inference_params=self.inference_params)
+
+        # Send output to the next stage.
+        send_forward(output_tensor)
+
+        # Reset the sequence length to whatever it was before.
+        args.seq_length = orig_seq_length
+
+        # Make sure we do not allocate context memory anymore.
+        if self.inference_params.allocate_key_value_memory:
+            self.inference_params.allocate_key_value_memory = False
+
+        return output_tensor


 def forward_step(model, tokens, position_ids, attention_mask, inference_params):

     # Hidden size changes when not using recompute, need to tell p2p_communicate
@@ -56,7 +118,3 @@ def forward_step(model, tokens, position_ids, attention_mask, inference_params):
     args.seq_length = orig_seq_length

     return output_tensor


megatron/inference/generation.py  (+5 −14)

@@ -15,7 +15,6 @@
 """Generation utilities."""

 import torch
 import torch.nn.functional as F
@@ -25,7 +24,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step, InferenceParams
+from .forward_step import InferenceForwardStep
 from .sampling import sample
@@ -66,6 +65,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = tokens.size(1)
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

+    # Forward step.
+    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
+
     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
     if hasattr(args, 'eos_id'):
@@ -109,20 +111,10 @@ def generate_tokens_probs_and_return_on_first_stage(
     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)

-    # Set inference params
-    inference_params = InferenceParams([batch_size], max_sequence_length)
-
-    model.eval()
     with torch.no_grad():
         prev_context_length = 0
         for context_length in range(min_prompt_length, max_sequence_length):

-            # If we are starting from scratch, allocate memory for the entire
-            # context, otherwise set this to false so the memory is not
-            # reallocated.
-            inference_params.allocate_key_value_memory = \
-                (prev_context_length == 0)
-
             # Pick the slice that we need to pass through the network.
             tokens2use = tokens[:, prev_context_length:context_length]
             positions2use = position_ids[:, prev_context_length:context_length]
@@ -130,8 +122,7 @@
                 ..., prev_context_length:context_length, :context_length]

             # logits will be meaningful only in the last pipeline stage.
-            logits = forward_step(model, tokens2use, positions2use,
-                                  attention_mask2use, inference_params)
+            logits = forward_step(tokens2use, positions2use, attention_mask2use)

             if mpu.is_pipeline_last_stage():
                 # Always the last stage should have an output.
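For orientation, the net effect of the refactor is that the decoding loop in generate_tokens_probs_and_return_on_first_stage now drives a single InferenceForwardStep object instead of threading model and inference_params through every call: the object puts the model in eval mode, owns the InferenceParams, allocates the key/value cache on its first invocation, and reuses it afterwards. The sketch below condenses that calling pattern from the hunks above; it is an illustration, not code from the PR, and it assumes model, tokens, position_ids, attention_mask, batch_size, min_prompt_length, and max_sequence_length are prepared exactly as in the existing function.

    import torch
    from megatron.inference.forward_step import InferenceForwardStep

    # Construct once: puts the model in eval mode and owns the InferenceParams,
    # so the first call allocates the key/value cache and later calls reuse it.
    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)

    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):
            # Only the not-yet-processed slice goes through the network;
            # earlier positions are served from the cached key/value memory.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # No model / inference_params arguments any more; the callable
            # carries both and clears allocate_key_value_memory after its
            # first invocation.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            # ... sample the next token from logits (last pipeline stage only),
            # append it to tokens, then advance the window.
            prev_context_length = context_length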