megatron/inference/forward_step.py (+114 −41)

```diff
@@ -15,15 +15,33 @@
 """Forward step utilities."""
 
+from abc import ABC
+from abc import abstractmethod
 from collections.abc import Iterable
-from enum import Enum
 
-from megatron.p2p_communication import recv_forward, send_forward
-from megatron import get_args
+import torch
 
+from megatron import (
+    get_args,
+    mpu)
+from megatron.p2p_communication import (
+    recv_forward,
+    send_forward)
 
-class ForwardStepTypes(Enum):
-    NO_PIPELINING = 1
+
+def forward_step_provider(model, batch_size, micro_batch_size,
+                          max_sequence_len):
+    args = get_args()
+    if args.pipeline_model_parallel_size == 1 or \
+       micro_batch_size >= batch_size:
+        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
+    return SimplePipeliningForwardStep(model, batch_size, micro_batch_size,
+                                       max_sequence_len)
@@ -37,12 +55,12 @@ class InferenceParams:
         self.micro_batch_size_list = micro_batch_size_list
         self.max_sequence_len = max_sequence_len
         self.allocate_key_value_memory = True
-        self.micro_batch_size_index = 0
+        self.micro_batch_index = 0
 
 
-class InferenceForwardStep:
+class ForwardStepBase(ABC):
 
-    def __init__(self, model, batch_size, max_sequence_len):
+    def __init__(self, model):
 
         if isinstance(model, Iterable):
             for this_model in model:
@@ -51,70 +69,125 @@ class InferenceForwardStep:
                 model.eval()
         self.model = model
 
-        self.inference_params = InferenceParams([batch_size],
-                                                max_sequence_len)
-        self.forward_step_type = ForwardStepTypes.NO_PIPELINING
+    @abstractmethod
+    def __call__(self, tokens, position_ids, attention_mask):
+        pass
 
-    def __call__(self, tokens, position_ids, attention_mask):
-        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
-            return self._forward_step_no_pipelining(tokens, position_ids,
-                                                    attention_mask)
-        raise Exception('unknown forward step type {}'.format(
-            self.forward_step_type))
 
-    def _forward_step_no_pipelining(self, tokens, position_ids,
-                                    attention_mask):
+class SimplePipeliningForwardStep(ForwardStepBase):
+
+    def __init__(self, model, batch_size, micro_batch_size,
+                 max_sequence_len):
+        super().__init__(model)
+        self.batch_size = batch_size
+        # Divide the batch dimension into micro batches.
+        self.num_micro_batches, last_chunk = divmod(batch_size,
+                                                    micro_batch_size)
+        self.micro_batch_size_list = []
+        self.batch_dim_start_index = [0]
+        for i in range(self.num_micro_batches):
+            self.micro_batch_size_list.append(micro_batch_size)
+            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
+        if last_chunk > 0:
+            self.num_micro_batches += 1
+            self.micro_batch_size_list.append(last_chunk)
+            self.batch_dim_start_index.append(batch_size)
+
+        self.inference_params = InferenceParams(self.micro_batch_size_list,
+                                                max_sequence_len)
+
+    def __call__(self, tokens, position_ids, attention_mask):
         # Need to tell p2p_communicate functions the correct size.
         args = get_args()
         orig_seq_length = args.seq_length
-        args.seq_length = tokens.shape[1]
+        args.seq_length = tokens.size(1)
         assert args.seq_length <= self.inference_params.max_sequence_len
-        args.micro_batch_size = tokens.shape[0]
-        assert self.inference_params.micro_batch_size_list[0] == \
-            tokens.shape[0]
-        assert self.inference_params.micro_batch_size_index == 0
-
-        # Receive from previous stage.
-        input_tensor = recv_forward()
-
-        # Forward pass through the model.
-        self.model.set_input_tensor(input_tensor)
-        output_tensor = self.model(tokens, position_ids, attention_mask,
-                                   inference_params=self.inference_params)
-
-        # Send output to the next stage.
-        send_forward(output_tensor)
-
-        # Reset the sequence length to whatever it was before.
-        args.seq_length = orig_seq_length
-
-        # Make sure we do not allocate context memory anymore.
-        if self.inference_params.allocate_key_value_memory:
-            self.inference_params.allocate_key_value_memory = False
-
-        return output_tensor
+        # Preallocate memory for output logits.
+        logits = None
+        if mpu.is_pipeline_last_stage():
+            logits = torch.empty(tokens.size(0), tokens.size(1),
+                                 args.padded_vocab_size,
+                                 dtype=torch.float32,
+                                 device=torch.cuda.current_device())
+
+        # Pipeline using micro batches.
+        for micro_batch_index in range(self.num_micro_batches):
+            # Set micro-batch size and index.
+            self.inference_params.micro_batch_index = micro_batch_index
+            args.micro_batch_size = self.micro_batch_size_list[
+                micro_batch_index]
+            # Slice along the batch dimension.
+            start = self.batch_dim_start_index[micro_batch_index]
+            end = self.batch_dim_start_index[micro_batch_index + 1]
+            tokens2use = tokens[start:end, ...]
+            position_ids2use = position_ids[start:end, ...]
+
+            # Receive from previous stage.
+            input_tensor = recv_forward()
+
+            # Forward pass through the model.
+            self.model.set_input_tensor(input_tensor)
+            output_tensor = self.model(tokens2use, position_ids2use,
+                                       attention_mask,
+                                       inference_params=self.inference_params)
+
+            # Send output to the next stage.
+            send_forward(output_tensor)
+
+            if mpu.is_pipeline_last_stage():
+                logits[start:end, ...] = output_tensor
+
+        # Adjust the sequence length back to whatever it was before.
+        args.seq_length = orig_seq_length
+
+        return logits
 
 
-def forward_step(model, tokens, position_ids, attention_mask,
-                 inference_params):
-    # Hidden size changes when not using recompute, need to tell
-    # p2p_communicate functions the correct size.
-    args = get_args()
-    orig_seq_length = args.seq_length
-    args.seq_length = tokens.shape[1]
-    args.micro_batch_size = tokens.shape[0]
-
-    # Receive from previous stage.
-    input_tensor = recv_forward()
-
-    # Forward pass through the model.
-    model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          inference_params=inference_params)
-
-    # Send output to the next stage.
-    send_forward(output_tensor)
-
-    # Reset the sequence length to whatever it was before.
-    args.seq_length = orig_seq_length
-
-    return output_tensor
+class NoPipeliningForwardStep(ForwardStepBase):
+
+    def __init__(self, model, batch_size, max_sequence_len):
+        super().__init__(model)
+        self.inference_params = InferenceParams([batch_size],
+                                                max_sequence_len)
+
+    def __call__(self, tokens, position_ids, attention_mask):
+        # Need to tell p2p_communicate functions the correct size.
+        args = get_args()
+        orig_seq_length = args.seq_length
+        args.seq_length = tokens.shape[1]
+        assert args.seq_length <= self.inference_params.max_sequence_len
+        args.micro_batch_size = tokens.shape[0]
+        assert self.inference_params.micro_batch_size_list[0] == \
+            tokens.shape[0]
+        assert self.inference_params.micro_batch_index == 0
+
+        # Receive from previous stage.
+        input_tensor = recv_forward()
+
+        # Forward pass through the model.
+        self.model.set_input_tensor(input_tensor)
+        output_tensor = self.model(tokens, position_ids, attention_mask,
+                                   inference_params=self.inference_params)
+
+        # Send output to the next stage.
+        send_forward(output_tensor)
+
+        # Reset the sequence length to whatever it was before.
+        args.seq_length = orig_seq_length
+
+        # Make sure we do not allocate context memory anymore.
+        if self.inference_params.allocate_key_value_memory:
+            self.inference_params.allocate_key_value_memory = False
+
+        return output_tensor
```
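Note on the batch splitting above: `SimplePipeliningForwardStep.__init__` uses `divmod` to cut the batch into equal micro batches plus an optional smaller trailing chunk, recording the slice boundaries in `batch_dim_start_index`. A minimal standalone sketch of that logic (the helper name `split_batch` is illustrative and not part of the diff):

```python
# Standalone sketch of the micro-batch splitting performed in
# SimplePipeliningForwardStep.__init__; `split_batch` is a hypothetical
# helper used only for illustration.
def split_batch(batch_size, micro_batch_size):
    # Number of full micro batches, plus a possibly smaller last chunk.
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    sizes = [micro_batch_size] * num_micro_batches
    starts = [i * micro_batch_size for i in range(num_micro_batches + 1)]
    if last_chunk > 0:
        num_micro_batches += 1
        sizes.append(last_chunk)
        starts.append(batch_size)
    return num_micro_batches, sizes, starts

# A batch of 10 with micro batches of 4 yields sizes [4, 4, 2] and slice
# boundaries [0, 4, 8, 10], matching tokens[start:end, ...] in __call__.
assert split_batch(10, 4) == (3, [4, 4, 2], [0, 4, 8, 10])
```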
megatron/inference/generation.py (+3 −2)

```diff
@@ -24,7 +24,7 @@
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import InferenceForwardStep
+from .forward_step import forward_step_provider
 from .sampling import sample
@@ -66,7 +66,8 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length,
                               args.max_position_embeddings)
 
     # forward step.
-    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
+    forward_step = forward_step_provider(model, batch_size, 4,
+                                         max_sequence_length)
 
     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
```
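The call site above hard-codes a micro batch size of 4. Combined with the dispatch rule in `forward_step_provider`, pipelined inference therefore only engages when there is more than one pipeline stage and the batch is larger than 4. A pure-function sketch of that rule (the name `choose_forward_step` is illustrative; it returns the class the provider would instantiate):

```python
# Illustrative restatement of the dispatch in forward_step_provider.
def choose_forward_step(pipeline_model_parallel_size, batch_size,
                        micro_batch_size):
    # Pipelining only pays off with multiple stages and multiple micro
    # batches per batch; otherwise use the simple path.
    if pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
        return 'NoPipeliningForwardStep'
    return 'SimplePipeliningForwardStep'

assert choose_forward_step(1, 16, 4) == 'NoPipeliningForwardStep'
assert choose_forward_step(2, 16, 4) == 'SimplePipeliningForwardStep'
assert choose_forward_step(2, 4, 4) == 'NoPipeliningForwardStep'
```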
megatron/model/transformer.py (+6 −5)

```diff
@@ -269,18 +269,19 @@ class ParallelAttention(MegatronModule):
         # ==================================
 
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_size_index
+            inf_batch_index = inference_params.micro_batch_index
             assert key_layer.size(1) == \
                 inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
             start = self.inference_current_sequence_len_list[inf_batch_index]
             end = start + key_layer.size(0)
             assert end <= inference_params.max_sequence_len
             self.inference_current_sequence_len_list[inf_batch_index] = end
             # Copy keys and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
-                key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
-                value_layer
+            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
+                = key_layer
+            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
+                = value_layer
             key_layer = \
                 self.inference_key_memory_list[inf_batch_index][:end, ...]
             value_layer = \
```
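For reference, the hunk above renames `micro_batch_size_index` to `micro_batch_index` and keeps one key/value buffer per micro batch: each call appends the new keys and values at the cached sequence offset for that micro batch, then attends over the full cached prefix. A minimal torch sketch of that incremental update, assuming Megatron's `[sequence, batch, ...]` tensor layout (the variable names are illustrative, not the module's real attributes):

```python
import torch

# Sketch of the per-micro-batch KV-cache append done in ParallelAttention.
max_sequence_len, micro_batch_size, hidden = 8, 2, 4
key_memory = torch.zeros(max_sequence_len, micro_batch_size, hidden)
current_sequence_len = 0

def append_keys(key_layer):
    """Write new keys after the cached prefix; return the full prefix."""
    global current_sequence_len
    start = current_sequence_len
    end = start + key_layer.size(0)
    assert end <= max_sequence_len
    key_memory[start:end, ...] = key_layer
    current_sequence_len = end
    return key_memory[:end, ...]

prompt_keys = torch.randn(5, micro_batch_size, hidden)  # prompt pass
step_keys = torch.randn(1, micro_batch_size, hidden)    # one decode step
assert append_keys(prompt_keys).size(0) == 5
assert append_keys(step_keys).size(0) == 6
```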