megatron/inference/api.py (new file, +56 −0)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""

import torch

from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts


def generate(model, prompts=None, tokens_to_generate=0,
             return_output_log_probs=False, return_all_log_probs=False,
             temperature=1.0):
    """TO DO ..."""

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature]
    values_float_tensor = broadcast_float_list(4, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    # Temperature is the fourth broadcast value (index 3).
    temperature = values_float_tensor[3].item()

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
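For context, a minimal driver sketch (not part of this diff; run_generation, the prompt strings, and the parameter values are illustrative, and the four-value unpacking assumes the return signature of the generation function shown further below): every rank calls generate() collectively, but only rank 0 needs real prompts, since the scalar options are broadcast and the prompts are tokenized and broadcast inside the call.

import torch

def run_generation(model):
    # Only rank 0 needs the actual prompt strings.
    prompts = None
    if torch.distributed.get_rank() == 0:
        prompts = ['Deep learning is', 'The capital of France is']
    # Collective call: every rank must enter it; the outputs are only
    # meaningful on the first pipeline stage.
    tokens, lengths, output_log_probs, all_log_probs = generate(
        model, prompts=prompts, tokens_to_generate=32,
        return_output_log_probs=True, temperature=0.9)
    return tokens, lengths, output_log_probs, all_log_probs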
megatron/inference/communication.py (+39 −8)

A new broadcast_from_last_to_first_pipeline_stage helper is added, and copy_from_last_to_first_pipeline_stage is reworked to handle non-contiguous tensors: the in-place copy back on the first stage now happens only when a separate contiguous staging buffer was used (the old unconditional "if is_first_stage:" copy is replaced). Updated code for the two functions:

def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # Only first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            assert tensor is not None
            assert tensor.is_cuda
            assert tensor.is_contiguous()
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor


def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # Only first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        assert tensor is not None
        assert tensor.is_cuda
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
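A rough usage sketch of the two helpers (not from the diff; my_result and existing_tensor are hypothetical placeholders): the broadcast variant allocates a fresh receive buffer on the first stage, while the copy variant updates a tensor the first stage already owns.

size = (4, 128)          # placeholder shape
dtype = torch.int64

# Broadcast variant: the last stage supplies the tensor, the first stage
# passes None and receives a newly allocated tensor, other stages get None.
tensor = my_result if mpu.is_pipeline_last_stage() else None  # hypothetical
tensor = broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor)

# Copy variant: both stages pass an existing tensor of the same shape; the
# first stage's tensor is updated in place (through a contiguous staging
# buffer only if needed).
copy_from_last_to_first_pipeline_stage(size, dtype, existing_tensor)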
megatron/inference/generation.py (+104 −33)

generate_tokens() is renamed to generate_tokens_probs_and_return_on_first_stage(), gains a full docstring, allocates its log-probability buffers only on the last pipeline stage, and broadcasts the results to the first stage instead of returning a different tuple per stage (the old full_logits and context_length + 1 return values are dropped). Updated code for the shown hunks:

import torch
import torch.nn.functional as F

from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step
from .sampling import sample


def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        return_all_log_probs=False,
        temperature=1.0):
    """Main token generation function.

    Arguments:
        model: XXX
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after the logits are modified for sampling.
        return_all_log_probs: flag to calculate the log probability across
            all tokens (the full vocabulary). Note that the log probability
            is the one after the logits are modified for sampling.
        temperature: sampling temperature.
    Note: apart from the model, the other parameters only need to be
        available on rank 0.
    Outputs: note that the sequence size is smaller than
        max-sequence-length if generation terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
        all_log_probs: log probability of all tokens. size: [b, s, vocab-size]
    """

    args = get_args()
    tokenizer = get_tokenizer()

    # ... (unchanged lines omitted) ...

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Log probability of all tokens for the sequence.
    all_log_probs = None
    all_log_probs_size = (batch_size, max_sequence_length - 1,
                          args.padded_vocab_size)
    # Lengths of the generated sequences, including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        if return_all_log_probs:
            all_log_probs = torch.empty(all_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
        # Lengths of the generated sequences, including prompts.
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # ... (unchanged lines omitted) ...

    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    attention_mask, position_ids = _build_attention_mask_and_position_ids(
        tokens)

    # ... (unchanged lines omitted) ...

            # Update the tokens.
            tokens[started, context_length] = new_sample[started]

            # Calculate the log probabilities.
            if return_output_log_probs or return_all_log_probs:
                log_probs = F.log_softmax(logits, dim=2)
                if return_all_log_probs:
                    all_log_probs[:, prev_context_length:context_length, :] = \
                        log_probs
                if return_output_log_probs:
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that the next input token is
                    # the token which we selected from the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[:, (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:, prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.

    # ... (unchanged lines omitted) ...

        if done:
            break

    # ====================================================
    # Update the length based on the max generated length.
    # ====================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]
        if return_all_log_probs:
            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
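As a standalone illustration of the shift-by-one gather above (a toy sketch, not from the diff, with made-up sizes): the logits produced at position i score the token that ends up at position i + 1.

import torch
import torch.nn.functional as F

batch, steps, vocab = 2, 3, 5
logits = torch.randn(batch, steps, vocab)          # distributions for steps 0..2
tokens = torch.randint(vocab, (batch, steps + 1))  # tokens at positions 0..3

log_probs = F.log_softmax(logits, dim=2)
# The distribution at step i is evaluated at the token chosen for step i + 1.
indices = torch.unsqueeze(tokens[:, 1:steps + 1], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
assert output_log_probs.shape == (batch, steps)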
    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
    if return_all_log_probs:
        all_log_probs_size = (batch_size, context_length,
                              args.padded_vocab_size)
        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
            all_log_probs_size, torch.float32, all_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs, \
        all_log_probs


def _build_attention_mask_and_position_ids(tokens):
    # ... (unchanged lines omitted) ...


megatron/inference/tokenization.py (+33 −0)

A detokenize_generations helper is added above tokenize_prompts:

from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor


def detokenize_generations(tokens_gpu_tensor, lengths_gpu_tensor,
                           return_segments):
    """Detokenize the generated tokens."""

    tokenizer = get_tokenizer()
    prompts_plus_generations = []
    if return_segments:
        prompts_plus_generations_segments = []

    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
    for sequence_tokens, length in zip(tokens, lengths):
        sequence_tokens = sequence_tokens[:length]
        prompts_plus_generations.append(
            tokenizer.detokenize(sequence_tokens))
        if return_segments:
            words = []
            for token in sequence_tokens:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            prompts_plus_generations_segments.append(words)

    if return_segments:
        return tokens, prompts_plus_generations, \
            prompts_plus_generations_segments

    return tokens, prompts_plus_generations


def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
    """Tokenize prompts and make them available on all ranks."""
    # ... (unchanged lines omitted) ...
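A short usage sketch of the new helper (placeholder inputs; tokens and generated_sequence_lengths are assumed to be the GPU tensors produced by the generation function above, and a GPT-2-style byte-level BPE tokenizer wrapper is assumed, as Megatron uses):

token_lists, texts, segments = detokenize_generations(
    tokens, generated_sequence_lengths, return_segments=True)
# texts[i] is the prompt plus generation for sample i; segments[i] splits
# the same sequence into per-token strings, decoding the byte-level BPE
# symbols as UTF-8 with errors='replace' so that partial multi-byte
# characters do not raise.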
megatron/text_generation_utils.py (+11 −2)

In synced_generate(), debug prints are added around the output_logits / full_logits broadcasts, the first-stage receive buffer for full_logits is fixed to use context_length - 1 (matching the sender), and generate() no longer converts full_logits to a Python list. Updated code for the shown hunks:

    if mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        print('last rank output size {} {} | \n'.format(
            output_logits.size(0), output_logits.size(1)))
        torch.distributed.broadcast(output_logits, src, group)

        if all_probs:
            print('last rank full size {} {} {} | \n'.format(
                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)

    # ... (unchanged lines omitted) ...

        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        output_logits = torch.empty(tokens.size(0), context_length - 1,
                                    dtype=torch.float32,
                                    device=torch.device("cuda"))
        print('first rank output size {} {} | \n'.format(
            output_logits.size(0), output_logits.size(1)))
        torch.distributed.broadcast(output_logits, src, group)

        if all_probs:
            args = get_args()
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            full_logits = torch.empty(tokens.size(0), context_length - 1,
                                      args.padded_vocab_size,
                                      dtype=torch.float32,
                                      device=torch.device("cuda"))
            print('first rank full size {} {} {} | \n'.format(
                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
            torch.distributed.broadcast(full_logits, src, group)

    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits

In generate():

        output_logits = output_logits.cpu().numpy().tolist()
        if all_probs:
            full_logits = full_logits.cpu().numpy()  # no longer .tolist()

    return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
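One way to read the context_length - 1 shape (a sketch with placeholder sizes, not from the diff): log probabilities are produced for the tokens at positions 1 .. context_length - 1, each scored by the distribution at the previous position, so both receive buffers have context_length - 1 entries along the sequence dimension and must match what the last stage broadcasts.

import torch

batch, context_length, vocab = 2, 10, 50257   # placeholder sizes
# One prediction per position 0 .. context_length - 2  ->  context_length - 1 rows.
output_logits = torch.empty(batch, context_length - 1, dtype=torch.float32)
full_logits = torch.empty(batch, context_length - 1, vocab, dtype=torch.float32)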