Loading megatron/text_generation/api.py +10 −19 Original line number Diff line number Diff line Loading @@ -31,7 +31,6 @@ def generate_and_post_process(model, prompts=None, tokens_to_generate=0, return_output_log_probs=False, return_all_log_probs=False, greedy_sampling=False, top_k_sampling=0, top_p_sampling=0.0, Loading @@ -42,12 +41,11 @@ def generate_and_post_process(model, move to cpu and convert to list.""" # Main inference. tokens, lengths, output_log_probs, all_log_probs = generate( tokens, lengths, output_log_probs = generate( model, prompts=prompts, tokens_to_generate=tokens_to_generate, return_output_log_probs=return_output_log_probs, return_all_log_probs=return_all_log_probs, greedy_sampling=greedy_sampling, top_k_sampling=top_k_sampling, top_p_sampling=top_p_sampling, Loading @@ -63,11 +61,9 @@ def generate_and_post_process(model, if return_output_log_probs: output_log_probs = output_log_probs.cpu().numpy().tolist() if return_all_log_probs: all_log_probs = all_log_probs.cpu().numpy().tolist() return prompts_plus_generations, prompts_plus_generations_segments, \ output_log_probs, all_log_probs, tokens output_log_probs, tokens return None Loading @@ -77,7 +73,6 @@ def generate(model, prompts=None, tokens_to_generate=0, return_output_log_probs=False, return_all_log_probs=False, greedy_sampling=False, top_k_sampling=0, top_p_sampling=0.0, Loading @@ -90,24 +85,21 @@ def generate(model, discard tokens in the tokens tensor that are after the corresponding length. output_log_probs: log probs of the tokens. all_log_probs: full log probs for all of tokens. """ # Make sure input params are avaialble to all ranks. values = [tokens_to_generate, return_output_log_probs, return_all_log_probs, values = [tokens_to_generate, return_output_log_probs, greedy_sampling, top_k_sampling, top_p_sampling, temperature, add_BOS, use_eod_token_for_early_termination] values_float_tensor = broadcast_float_list(9, float_list=values) values_float_tensor = broadcast_float_list(8, float_list=values) tokens_to_generate = int(values_float_tensor[0].item()) return_output_log_probs = bool(values_float_tensor[1].item()) return_all_log_probs = bool(values_float_tensor[2].item()) greedy_sampling = bool(values_float_tensor[3].item()) top_k_sampling = int(values_float_tensor[4].item()) top_p_sampling = values_float_tensor[5].item() temperature = values_float_tensor[6].item() add_BOS = bool(values_float_tensor[7].item()) use_eod_token_for_early_termination = bool(values_float_tensor[8].item()) greedy_sampling = bool(values_float_tensor[2].item()) top_k_sampling = int(values_float_tensor[3].item()) top_p_sampling = values_float_tensor[4].item() temperature = values_float_tensor[5].item() add_BOS = bool(values_float_tensor[6].item()) use_eod_token_for_early_termination = bool(values_float_tensor[7].item()) # Tokenize prompts and get the batch. # Note that these tensors are broadcaseted to all ranks. Loading @@ -122,7 +114,6 @@ def generate(model, return generate_tokens_probs_and_return_on_first_stage( model, context_tokens_tensor, context_length_tensor, return_output_log_probs=return_output_log_probs, return_all_log_probs=return_all_log_probs, greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling, temperature=temperature, use_eod_token_for_early_termination=use_eod_token_for_early_termination) megatron/text_generation/communication.py +35 −13 Original line number Diff line number Diff line Loading @@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None): def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): """Broadcast a tensor from last pipeline stage to all ranks.""" if mpu.is_pipeline_last_stage(): def _is_cuda(tensor): """Check if a tensor is not none and is cuda.""" assert tensor is not None assert tensor.is_cuda def _is_cuda_contiguous(tensor): """Check if a tensor is not none, is cuda, and is contiguous.""" _is_cuda(tensor) assert tensor.is_contiguous() def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): """Broadcast a tensor from last pipeline stage to all ranks.""" is_last_stage = mpu.is_pipeline_last_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if mpu.is_pipeline_first_stage() and is_last_stage: return tensor if is_last_stage: _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading @@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Broadcast tensor values from last stage into the first stage.""" # Only first and last stage pipeline stages need to be involved. is_last_stage = mpu.is_pipeline_last_stage() is_first_stage = mpu.is_pipeline_first_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if is_first_stage and is_last_stage: return tensor # Only first and last stage pipeline stages need to be involved. if is_last_stage or is_first_stage: if is_last_stage: assert tensor is not None assert tensor.is_cuda assert tensor.is_contiguous() _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading @@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Copy tensor values from last stage into the first stage. Note that the input tensor is updated in place.""" # Only first and last stage pipeline stages need to be involved. is_last_stage = mpu.is_pipeline_last_stage() is_first_stage = mpu.is_pipeline_first_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if is_first_stage and is_last_stage: return # Only first and last stage pipeline stages need to be involved. if is_last_stage or is_first_stage: assert tensor is not None assert tensor.is_cuda _is_cuda(tensor) is_contiguous = tensor.is_contiguous() src = mpu.get_pipeline_model_parallel_last_rank() group = mpu.get_embedding_group() Loading @@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0): """ if torch.distributed.get_rank() == rank: assert tensor is not None assert tensor.is_cuda _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading megatron/text_generation/generation.py +3 −29 Original line number Diff line number Diff line Loading @@ -31,7 +31,6 @@ from .sampling import sample def generate_tokens_probs_and_return_on_first_stage( model, tokens, lengths, return_output_log_probs=False, return_all_log_probs=False, greedy=False, top_k=0, top_p=0.0, temperature=1.0, use_eod_token_for_early_termination=True): Loading @@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage( return_output_log_probs: flag to calculate the log probability of the generated tokens. Note that the log probability is the one after logits are modifed for sampling. return_all_log_probs: flag to calculate the log probability of across all the tokens (vocab size). Note that the log probability is the one after logits are modifed for sampling. greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters. Note that these three paramters are exclusive meaning that: if greedy = true then we should have top-k=top-p=0. Loading @@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage( generated_sequence_lengths: total length (including prompt) of the generated sequence. size: [b] output_log_probs: log probability of the selected tokens. size: [b, s] all_log_probs: log probability of all the tokens. size: [b, s, vocab-size] """ args = get_args() Loading Loading @@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage( # Log probability of the sequence (prompt + generated tokens). output_log_probs = None output_log_probs_size = (batch_size, max_sequence_length - 1) # Log probability of all tokens for the sequence. all_log_probs = None all_log_probs_size = (batch_size, max_sequence_length -1, args.padded_vocab_size) # Lengths of generated seuquence including including prompts. generated_sequence_lengths = None if mpu.is_pipeline_last_stage(): Loading @@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage( output_log_probs = torch.empty(output_log_probs_size, dtype=torch.float32, device=torch.cuda.current_device()) if return_all_log_probs: all_log_probs = torch.empty(all_log_probs_size, dtype=torch.float32, device=torch.cuda.current_device()) generated_sequence_lengths = torch.ones( batch_size, dtype=torch.int64, device=torch.cuda.current_device()) * max_sequence_length Loading Loading @@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage( tokens[started, context_length] = new_sample[started] # Calculate the log probabilities. if return_output_log_probs or return_all_log_probs: if return_output_log_probs: log_probs = F.log_softmax(logits, dim=2) if return_all_log_probs: all_log_probs[:, prev_context_length:context_length, :] = log_probs if return_output_log_probs: # Pick the tokens that we need to get the log # probabilities for. Note that next input token is Loading Loading @@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage( if mpu.is_pipeline_last_stage(): if return_output_log_probs: output_log_probs = output_log_probs[:, :context_length] if return_all_log_probs: all_log_probs = all_log_probs[:, :context_length, :] # ====================================== # Broadcast to the first pipeline stage. Loading @@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage( output_log_probs_size = (batch_size, context_length) output_log_probs = broadcast_from_last_to_first_pipeline_stage( output_log_probs_size, torch.float32, output_log_probs) if return_all_log_probs: all_log_probs_size = (batch_size, context_length, args.padded_vocab_size) all_log_probs = broadcast_from_last_to_first_pipeline_stage( all_log_probs_size, torch.float32, all_log_probs) return tokens, generated_sequence_lengths, output_log_probs, \ all_log_probs return tokens, generated_sequence_lengths, output_log_probs Loading megatron/text_generation_server.py +1 −2 Original line number Diff line number Diff line Loading @@ -101,13 +101,12 @@ class MegatronGenerate(Resource): with lock: # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate response, response_seg, response_logprobs, _, _ = \ response, response_seg, response_logprobs, _ = \ generate_and_post_process( self.model, prompts=prompts, tokens_to_generate=tokens_to_generate, return_output_log_probs=logprobs, return_all_log_probs=False, greedy_sampling=args.greedy, top_k_sampling=top_k, top_p_sampling=top_p, Loading Loading
megatron/text_generation/api.py +10 −19 Original line number Diff line number Diff line Loading @@ -31,7 +31,6 @@ def generate_and_post_process(model, prompts=None, tokens_to_generate=0, return_output_log_probs=False, return_all_log_probs=False, greedy_sampling=False, top_k_sampling=0, top_p_sampling=0.0, Loading @@ -42,12 +41,11 @@ def generate_and_post_process(model, move to cpu and convert to list.""" # Main inference. tokens, lengths, output_log_probs, all_log_probs = generate( tokens, lengths, output_log_probs = generate( model, prompts=prompts, tokens_to_generate=tokens_to_generate, return_output_log_probs=return_output_log_probs, return_all_log_probs=return_all_log_probs, greedy_sampling=greedy_sampling, top_k_sampling=top_k_sampling, top_p_sampling=top_p_sampling, Loading @@ -63,11 +61,9 @@ def generate_and_post_process(model, if return_output_log_probs: output_log_probs = output_log_probs.cpu().numpy().tolist() if return_all_log_probs: all_log_probs = all_log_probs.cpu().numpy().tolist() return prompts_plus_generations, prompts_plus_generations_segments, \ output_log_probs, all_log_probs, tokens output_log_probs, tokens return None Loading @@ -77,7 +73,6 @@ def generate(model, prompts=None, tokens_to_generate=0, return_output_log_probs=False, return_all_log_probs=False, greedy_sampling=False, top_k_sampling=0, top_p_sampling=0.0, Loading @@ -90,24 +85,21 @@ def generate(model, discard tokens in the tokens tensor that are after the corresponding length. output_log_probs: log probs of the tokens. all_log_probs: full log probs for all of tokens. """ # Make sure input params are avaialble to all ranks. values = [tokens_to_generate, return_output_log_probs, return_all_log_probs, values = [tokens_to_generate, return_output_log_probs, greedy_sampling, top_k_sampling, top_p_sampling, temperature, add_BOS, use_eod_token_for_early_termination] values_float_tensor = broadcast_float_list(9, float_list=values) values_float_tensor = broadcast_float_list(8, float_list=values) tokens_to_generate = int(values_float_tensor[0].item()) return_output_log_probs = bool(values_float_tensor[1].item()) return_all_log_probs = bool(values_float_tensor[2].item()) greedy_sampling = bool(values_float_tensor[3].item()) top_k_sampling = int(values_float_tensor[4].item()) top_p_sampling = values_float_tensor[5].item() temperature = values_float_tensor[6].item() add_BOS = bool(values_float_tensor[7].item()) use_eod_token_for_early_termination = bool(values_float_tensor[8].item()) greedy_sampling = bool(values_float_tensor[2].item()) top_k_sampling = int(values_float_tensor[3].item()) top_p_sampling = values_float_tensor[4].item() temperature = values_float_tensor[5].item() add_BOS = bool(values_float_tensor[6].item()) use_eod_token_for_early_termination = bool(values_float_tensor[7].item()) # Tokenize prompts and get the batch. # Note that these tensors are broadcaseted to all ranks. Loading @@ -122,7 +114,6 @@ def generate(model, return generate_tokens_probs_and_return_on_first_stage( model, context_tokens_tensor, context_length_tensor, return_output_log_probs=return_output_log_probs, return_all_log_probs=return_all_log_probs, greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling, temperature=temperature, use_eod_token_for_early_termination=use_eod_token_for_early_termination)
megatron/text_generation/communication.py +35 −13 Original line number Diff line number Diff line Loading @@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None): def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): """Broadcast a tensor from last pipeline stage to all ranks.""" if mpu.is_pipeline_last_stage(): def _is_cuda(tensor): """Check if a tensor is not none and is cuda.""" assert tensor is not None assert tensor.is_cuda def _is_cuda_contiguous(tensor): """Check if a tensor is not none, is cuda, and is contiguous.""" _is_cuda(tensor) assert tensor.is_contiguous() def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): """Broadcast a tensor from last pipeline stage to all ranks.""" is_last_stage = mpu.is_pipeline_last_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if mpu.is_pipeline_first_stage() and is_last_stage: return tensor if is_last_stage: _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading @@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Broadcast tensor values from last stage into the first stage.""" # Only first and last stage pipeline stages need to be involved. is_last_stage = mpu.is_pipeline_last_stage() is_first_stage = mpu.is_pipeline_first_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if is_first_stage and is_last_stage: return tensor # Only first and last stage pipeline stages need to be involved. if is_last_stage or is_first_stage: if is_last_stage: assert tensor is not None assert tensor.is_cuda assert tensor.is_contiguous() _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading @@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Copy tensor values from last stage into the first stage. Note that the input tensor is updated in place.""" # Only first and last stage pipeline stages need to be involved. is_last_stage = mpu.is_pipeline_last_stage() is_first_stage = mpu.is_pipeline_first_stage() # If first stage and last state are the same, then there is no # pipeline parallelism and no need to communicate. if is_first_stage and is_last_stage: return # Only first and last stage pipeline stages need to be involved. if is_last_stage or is_first_stage: assert tensor is not None assert tensor.is_cuda _is_cuda(tensor) is_contiguous = tensor.is_contiguous() src = mpu.get_pipeline_model_parallel_last_rank() group = mpu.get_embedding_group() Loading @@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0): """ if torch.distributed.get_rank() == rank: assert tensor is not None assert tensor.is_cuda _is_cuda_contiguous(tensor) else: tensor = torch.empty(size, dtype=dtype, Loading
megatron/text_generation/generation.py +3 −29 Original line number Diff line number Diff line Loading @@ -31,7 +31,6 @@ from .sampling import sample def generate_tokens_probs_and_return_on_first_stage( model, tokens, lengths, return_output_log_probs=False, return_all_log_probs=False, greedy=False, top_k=0, top_p=0.0, temperature=1.0, use_eod_token_for_early_termination=True): Loading @@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage( return_output_log_probs: flag to calculate the log probability of the generated tokens. Note that the log probability is the one after logits are modifed for sampling. return_all_log_probs: flag to calculate the log probability of across all the tokens (vocab size). Note that the log probability is the one after logits are modifed for sampling. greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters. Note that these three paramters are exclusive meaning that: if greedy = true then we should have top-k=top-p=0. Loading @@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage( generated_sequence_lengths: total length (including prompt) of the generated sequence. size: [b] output_log_probs: log probability of the selected tokens. size: [b, s] all_log_probs: log probability of all the tokens. size: [b, s, vocab-size] """ args = get_args() Loading Loading @@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage( # Log probability of the sequence (prompt + generated tokens). output_log_probs = None output_log_probs_size = (batch_size, max_sequence_length - 1) # Log probability of all tokens for the sequence. all_log_probs = None all_log_probs_size = (batch_size, max_sequence_length -1, args.padded_vocab_size) # Lengths of generated seuquence including including prompts. generated_sequence_lengths = None if mpu.is_pipeline_last_stage(): Loading @@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage( output_log_probs = torch.empty(output_log_probs_size, dtype=torch.float32, device=torch.cuda.current_device()) if return_all_log_probs: all_log_probs = torch.empty(all_log_probs_size, dtype=torch.float32, device=torch.cuda.current_device()) generated_sequence_lengths = torch.ones( batch_size, dtype=torch.int64, device=torch.cuda.current_device()) * max_sequence_length Loading Loading @@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage( tokens[started, context_length] = new_sample[started] # Calculate the log probabilities. if return_output_log_probs or return_all_log_probs: if return_output_log_probs: log_probs = F.log_softmax(logits, dim=2) if return_all_log_probs: all_log_probs[:, prev_context_length:context_length, :] = log_probs if return_output_log_probs: # Pick the tokens that we need to get the log # probabilities for. Note that next input token is Loading Loading @@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage( if mpu.is_pipeline_last_stage(): if return_output_log_probs: output_log_probs = output_log_probs[:, :context_length] if return_all_log_probs: all_log_probs = all_log_probs[:, :context_length, :] # ====================================== # Broadcast to the first pipeline stage. Loading @@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage( output_log_probs_size = (batch_size, context_length) output_log_probs = broadcast_from_last_to_first_pipeline_stage( output_log_probs_size, torch.float32, output_log_probs) if return_all_log_probs: all_log_probs_size = (batch_size, context_length, args.padded_vocab_size) all_log_probs = broadcast_from_last_to_first_pipeline_stage( all_log_probs_size, torch.float32, all_log_probs) return tokens, generated_sequence_lengths, output_log_probs, \ all_log_probs return tokens, generated_sequence_lengths, output_log_probs Loading
megatron/text_generation_server.py +1 −2 Original line number Diff line number Diff line Loading @@ -101,13 +101,12 @@ class MegatronGenerate(Resource): with lock: # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate response, response_seg, response_logprobs, _, _ = \ response, response_seg, response_logprobs, _ = \ generate_and_post_process( self.model, prompts=prompts, tokens_to_generate=tokens_to_generate, return_output_log_probs=logprobs, return_all_log_probs=False, greedy_sampling=args.greedy, top_k_sampling=top_k, top_p_sampling=top_p, Loading