megatron/arguments.py  +13 −0

@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_inference_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:

@@ -273,6 +274,18 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


+def _add_inference_args(parser):
+    group = parser.add_argument_group(title='inference')
+
+    group.add_argument('--inference-batch-times-seqlen-threshold',
+                       type=int, default=512,
+                       help='During inference, if batch-size times '
+                       'sequence-length is smaller than this threshold '
+                       'then we will not use pipelining, otherwise we will.')
+
+    return parser
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
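The new flag only gates when the pipelined forward path is used. Below is a minimal sketch of the decision rule, assuming the same default of 512; the batch and sequence sizes are made-up values, and use_pipelining is a hypothetical helper for illustration, not part of Megatron.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--inference-batch-times-seqlen-threshold',
                        type=int, default=512)
    args = parser.parse_args([])  # rely on the default of 512

    def use_pipelining(batch_size, seq_len, threshold):
        # Pipelining kicks in only when batch-size * sequence-length
        # reaches the threshold (mirrors the help text above).
        return batch_size * seq_len >= threshold

    threshold = args.inference_batch_times_seqlen_threshold
    print(use_pipelining(8, 128, threshold))   # 1024 >= 512 -> True
    print(use_pipelining(2, 100, threshold))   # 200  <  512 -> False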
megatron/inference/api.py  +41 −11

@@ -26,14 +26,20 @@ from .tokenization import (
     detokenize_generations)


 def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
                               return_all_log_probs=False,
+                              greedy_sampling=False,
+                              top_k_sampling=0,
+                              top_p_sampling=0.0,
                               temperature=1.0,
-                              add_BOS=False):
-    """TO DO ..."""
+                              add_BOS=False,
+                              use_eod_token_for_early_termination=True):
+    """Run inference and post-process outputs, i.e., detokenize,
+    move to cpu and convert to list."""

     # Main inference.
     tokens, lengths, output_log_probs, all_log_probs = generate(

@@ -42,8 +48,12 @@ def generate_and_post_process(model,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
+        greedy_sampling=greedy_sampling,
+        top_k_sampling=top_k_sampling,
+        top_p_sampling=top_p_sampling,
         temperature=temperature,
-        add_BOS=add_BOS)
+        add_BOS=add_BOS,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)

     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():

@@ -62,24 +72,42 @@ def generate_and_post_process(model,
     return None


 def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
              return_all_log_probs=False,
+             greedy_sampling=False,
+             top_k_sampling=0,
+             top_p_sampling=0.0,
              temperature=1.0,
-             add_BOS=False):
-    """TO DO ..."""
+             add_BOS=False,
+             use_eod_token_for_early_termination=True):
+    """Given prompts and input parameters, run inference and return:
+       tokens: prompts plus the generated tokens.
+       lengths: length of the prompt + generations. Note that we can
+           discard tokens in the tokens tensor that are after the
+           corresponding length.
+       output_log_probs: log probs of the tokens.
+       all_log_probs: full log probs for all of the tokens.
+    """

     # Make sure input params are available to all ranks.
-    values = [tokens_to_generate, return_output_log_probs,
-              return_all_log_probs, temperature, add_BOS]
-    values_float_tensor = broadcast_float_list(5, float_list=values)
+    values = [tokens_to_generate, return_output_log_probs,
+              return_all_log_probs, greedy_sampling, top_k_sampling,
+              top_p_sampling, temperature, add_BOS,
+              use_eod_token_for_early_termination]
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     return_all_log_probs = bool(values_float_tensor[2].item())
-    temperature = values_float_tensor[3].item()
-    add_BOS = bool(values_float_tensor[4].item())
+    greedy_sampling = bool(values_float_tensor[3].item())
+    top_k_sampling = int(values_float_tensor[4].item())
+    top_p_sampling = values_float_tensor[5].item()
+    temperature = values_float_tensor[6].item()
+    add_BOS = bool(values_float_tensor[7].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcast to all ranks.

@@ -95,4 +123,6 @@ def generate(model,
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
-        temperature=temperature)
+        greedy=greedy_sampling,
+        top_k=top_k_sampling,
+        top_p=top_p_sampling,
+        temperature=temperature,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
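generate() now packs all nine scalar arguments, booleans and integers included, into one float list so a single broadcast_float_list call can ship them to every rank, and each rank casts the entries back to their original types afterwards. Here is a minimal single-process sketch of that round-trip; the distributed broadcast is replaced by a local tensor, and the values are made up.

    import torch

    tokens_to_generate = 64
    return_output_log_probs = True
    greedy_sampling = False
    top_k_sampling = 40
    top_p_sampling = 0.0
    temperature = 0.9

    # Pack everything as floats, as generate() does before broadcasting.
    values = [tokens_to_generate, return_output_log_probs, greedy_sampling,
              top_k_sampling, top_p_sampling, temperature]
    values_float_tensor = torch.tensor(values, dtype=torch.float32)

    # Each rank recovers the original types from the float tensor.
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    greedy_sampling = bool(values_float_tensor[2].item())
    top_k_sampling = int(values_float_tensor[3].item())
    top_p_sampling = values_float_tensor[4].item()
    temperature = values_float_tensor[5].item()
    assert tokens_to_generate == 64 and top_k_sampling == 40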
megatron/inference/communication.py  +36 −0

@@ -21,6 +21,38 @@
 import torch

 from megatron import mpu


+def recv_from_prev_pipeline_rank_(recv_buffer=None):
+    """Receive from previous pipeline stage and update the
+    input buffer inplace."""
+    if not mpu.is_pipeline_first_stage():
+        assert recv_buffer is not None
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv, recv_buffer,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
+def send_to_next_pipeline_rank(tensor=None):
+    """Send output to the next pipeline stage."""
+    if not mpu.is_pipeline_last_stage():
+        assert tensor is not None
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor,
+            mpu.get_pipeline_model_parallel_next_rank())
+        reqs = torch.distributed.batch_isend_irecv([send_next_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""

@@ -96,6 +128,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
             tensor[...] = tensor_

+
 def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """ Given size and type of a tensor on all ranks and the tensor value
     only on a specific rank, broadcast from that rank to all other ranks.

@@ -114,6 +147,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     return tensor

+
 def broadcast_list(size, dtype, list_values=None, rank=0):
     """Broadcast a list of values with a given type."""

@@ -125,12 +159,14 @@ def broadcast_list(size, dtype, list_values=None, rank=0):
     return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)

+
 def broadcast_int_list(size, int_list=None, rank=0):
     """Broadcast a list of integer values."""

     return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)

+
 def broadcast_float_list(size, float_list=None, rank=0):
     """Broadcast a list of float values."""
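Both helpers post a single P2P op through torch.distributed.batch_isend_irecv, wait on it, and then call torch.cuda.synchronize() to guard against the race condition noted in the comments. The following is a hedged sketch of how a middle pipeline stage might use them; it assumes torch.distributed and Megatron's mpu are already initialized, and the buffer shape and dtype are placeholders rather than Megatron's defaults.

    import torch
    from megatron.inference.communication import (
        recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank)

    def relay_activations(seq_len=32, batch_size=4, hidden_size=1024):
        # Buffer for the activations coming from the previous stage
        # (placeholder shape and dtype).
        recv_buffer = torch.empty(seq_len, batch_size, hidden_size,
                                  dtype=torch.float16, device='cuda')
        # Fill the buffer in place; this is a no-op on the first stage.
        recv_from_prev_pipeline_rank_(recv_buffer)
        # Placeholder for this stage's forward pass over recv_buffer.
        output_tensor = recv_buffer
        # Hand the activations on; this is a no-op on the last stage.
        send_to_next_pipeline_rank(output_tensor)
        return output_tensor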
megatron/inference/forward_step.py  +47 −32

@@ -22,14 +22,20 @@
 import torch

 from megatron import (
     get_args,
     mpu)
+from .communication import (
+    send_to_next_pipeline_rank,
+    recv_from_prev_pipeline_rank_)


 class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficiently calculate and store the context during inference."""

     def __init__(self, max_batch_size, max_sequence_len):
+        """Note that offsets are set to zero and we always set the
+        flag to allocate memory. After the first call, make sure to
+        set this flag to False."""
         self.max_sequence_len = max_sequence_len
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0

@@ -39,34 +45,46 @@ class InferenceParams:


 class ForwardStep:
     """Forward step function with all the communications.
     We use a class here to hide the inference parameters
     from the outside caller."""

     def __init__(self, model, max_batch_size, max_sequence_len):
         """Set values so we don't need to do it multiple times."""
         # Make sure model is in eval mode.
-        if isinstance(model, Iterable):
-            for this_model in model:
-                this_model.eval()
-        else:
-            model.eval()
+        assert not isinstance(model, Iterable), \
+            'interleaving schedule is not supported for inference'
+        model.eval()
         self.model = model
-        self.constant = 512
         # Initialize inference parameters.
         self.inference_params = InferenceParams(max_batch_size,
                                                 max_sequence_len)
+        # Pipelining arguments.
+        args = get_args()
+        self.pipeline_size_larger_than_one = (
+            args.pipeline_model_parallel_size > 1)
+        # Threshold of pipelining.
+        self.pipelining_batch_x_seqlen = \
+            args.inference_batch_times_seqlen_threshold

     def __call__(self, tokens, position_ids, attention_mask):
-        if tokens.size(0) * tokens.size(1) >= self.constant:
-            micro_batch_size = max(1, self.constant // tokens.size(1))
-            return _with_pipelining_forward_step(self.model, tokens,
-                                                 position_ids,
-                                                 attention_mask,
-                                                 self.inference_params,
-                                                 micro_batch_size)
-        else:
-            return _no_pipelining_forward_step(self.model, tokens,
-                                               position_ids,
-                                               attention_mask,
-                                               self.inference_params)
+        """Invocation of the forward methods. Note that
+        self.inference_params is being modified by the forward step."""
+        # Pipelining case.
+        if self.pipeline_size_larger_than_one:
+            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
+            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
+                micro_batch_size = \
+                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
+                return _with_pipelining_forward_step(self.model, tokens,
+                                                     position_ids,
+                                                     attention_mask,
+                                                     self.inference_params,
+                                                     micro_batch_size)
+
+        return _no_pipelining_forward_step(self.model, tokens,
+                                           position_ids,
+                                           attention_mask,
+                                           self.inference_params)

@@ -103,9 +121,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
         recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

     # Receive from previous stage.
-    if not mpu.is_pipeline_first_stage():
-        torch.distributed.recv(recv_buffer,
-                               src=mpu.get_pipeline_model_parallel_prev_rank())
+    recv_from_prev_pipeline_rank_(recv_buffer)

     # Forward pass through the model.
     model.set_input_tensor(recv_buffer)

@@ -113,9 +129,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
                           inference_params=inference_params)

     # Send output to the next stage.
-    if not mpu.is_pipeline_last_stage():
-        torch.distributed.send(output_tensor,
-                               mpu.get_pipeline_model_parallel_next_rank())
+    send_to_next_pipeline_rank(output_tensor)

     # Make sure we do not allocate context memory anymore.
     if inference_params.allocate_key_value_memory:

@@ -128,7 +142,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,

 def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                 inference_params, recv_buffer=None):
     """If recv_buffer is none, we will allocate one on the fly."""
     # Run a simple forward pass.
     output_tensor = _forward_step_helper(model, tokens, position_ids,
                                          attention_mask, inference_params,

@@ -143,9 +157,10 @@ def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
     return logits


 def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                   inference_params, micro_batch_size):
+    """No interleaving is supported."""
     sequence_length = tokens.size(1)
     batch_size = tokens.size(0)
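When the pipelined path is taken, the micro-batch size is derived from the same threshold: max(1, threshold // seq_len), so each micro-batch carries at most roughly threshold tokens. Here is a small arithmetic sketch; the numbers are illustrative, and the ceil-style handling of a remainder micro-batch is an assumption, since the diff does not show how _with_pipelining_forward_step actually splits the batch.

    import math

    def micro_batch_plan(batch_size, seq_len, threshold=512):
        # Mirrors max(1, threshold // seq_len) from ForwardStep.__call__;
        # the number of micro-batches is an assumed ceil split.
        micro_batch_size = max(1, threshold // seq_len)
        num_micro_batches = math.ceil(batch_size / micro_batch_size)
        return micro_batch_size, num_micro_batches

    print(micro_batch_plan(batch_size=16, seq_len=128))   # (4, 4)
    print(micro_batch_plan(batch_size=16, seq_len=2048))  # (1, 16)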
megatron/inference/generation.py  +18 −10

@@ -32,10 +32,12 @@ def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
         return_all_log_probs=False,
-        temperature=1.0):
+        greedy=False, top_k=0, top_p=0.0,
+        temperature=1.0,
+        use_eod_token_for_early_termination=True):
     """Main token generation function.
     Arguments:
-        model: XXX
+        model: no interleaving is supported.
         tokens: prompt tokens extended to be of size [b, max-sequence-length]
         lengths: original prompt length, size: [b]
         return_output_log_probs: flag to calculate the log probability of

@@ -44,7 +46,14 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_all_log_probs: flag to calculate the log probability
            across all the tokens (vocab size). Note that the log probability
            is the one after logits are modified for sampling.
+        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
+           Note that these three parameters are exclusive, meaning that:
+               if greedy = true then we should have top-k = top-p = 0.
+               if top-k > 0 then we expect greedy = false and top-p = 0.
+               if top-p > 0 then we check for greedy = false and top-k = 0.
         temperature: sampling temperature.
+        use_eod_token_for_early_termination: if True, do early termination
+           if all the sequences have reached this token.
         Note: Outside of model, other parameters only need to be available on
            rank 0.
     Outputs: Note that its size is adjusted to a lower value than

@@ -108,10 +117,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     # =============
     # Run inference
     # =============

-    with torch.no_grad():
-        attention_mask, position_ids = _build_attention_mask_and_position_ids(
-            tokens)
+    attention_mask, position_ids = _build_attention_mask_and_position_ids(
+        tokens)

+    with torch.no_grad():
         prev_context_length = 0
         for context_length in range(min_prompt_length, max_sequence_length):

@@ -132,9 +140,9 @@ def generate_tokens_probs_and_return_on_first_stage(
             last_token_logits = logits[:, -1, :]
             new_sample, updated_last_token_logits = sample(
                 last_token_logits,
-                greedy=args.greedy,
-                top_k=args.top_k,
-                top_p=args.top_p,
+                greedy=greedy,
+                top_k=top_k,
+                top_p=top_p,
                 temperature=temperature,
                 vocab_size=tokenizer.vocab_size)

             # Now that we have the sample and updated logits,

@@ -189,8 +197,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            #if done:
-            #    break
+            if use_eod_token_for_early_termination and done:
+                break

     # ===================================================
     # Update the lengths based on max generated length.
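With the sampling knobs exposed end to end, a caller picks exactly one of greedy, top-k, or top-p, matching the exclusivity rule spelled out in the docstring above. Below is a hedged usage sketch of the extended API; the prompt and parameter values are made up, model is assumed to be an already-initialized GPT model, and the structure of the return value is not shown in the diff, so it is left unpacked here.

    from megatron.inference.api import generate_and_post_process

    def run_top_k_generation(model):
        # Top-k sampling only: greedy stays False and top-p stays 0.0,
        # per the exclusivity note in generation.py.
        result = generate_and_post_process(
            model,
            prompts=["Deep learning is"],      # made-up prompt
            tokens_to_generate=32,
            return_output_log_probs=True,
            greedy_sampling=False,
            top_k_sampling=40,
            top_p_sampling=0.0,
            temperature=0.9,
            add_BOS=False,
            use_eod_token_for_early_termination=True)
        # Only the first pipeline stage returns results; other ranks get None.
        return result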