megatron/api_server.py +3 −2

@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
             return "all_probs must be a boolean value"
 
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
                             "logits": output_logits,
-                            "all_logits": full_logits})
+                            "all_logits": full_logits,
+                            "tokens": tokens})
 
         return jsonify({"sentences": resp_sentences,
                         "segments": resp_sentences_seg,
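The net effect on the HTTP API is one extra field: when all_probs is set, the response now carries the raw token ids alongside the detokenized sentences and logits. A minimal client sketch, assuming the Flask resource is routed at /generate on localhost:5000 over PUT with JSON keys matching the parameter names (none of which is confirmed by this diff):

    import requests

    # Hypothetical route, port, and HTTP verb; the diff shows only the
    # Resource class, not how it is registered with the Flask app.
    url = "http://localhost:5000/generate"
    payload = {
        "sentences": ["The quick brown fox"],
        "max_len": 32,
        "all_probs": True,  # also request per-token logits
    }
    resp = requests.put(url, json=payload).json()

    print(resp["sentences"])  # detokenized completions, as before
    print(resp["tokens"])     # new in this change: raw token ids, one list per sentence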
megatron/text_generation_utils.py +8 −6

@@ -121,7 +121,7 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
+    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()

@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
             torch.distributed.broadcast(output_logits, src, group)
 
         if all_probs:
+            args = get_args()
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
-            full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
+            full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
             torch.distributed.broadcast(full_logits, src, group)
 
     if tokens is not None:

@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
     tokenizer = get_tokenizer()
     resp_sentences = []
     resp_sentences_seg = []
-    for i in range(decode_tokens.size(0)):
-        decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+    decode_tokens = decode_tokens.cpu().numpy().tolist()
+    for decode_token in decode_tokens:
         resp_sentences.append(tokenizer.detokenize(decode_token))
         words = []
         for token in decode_token:

@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
         full_logits = full_logits.cpu().numpy().tolist()
 
     end = time.time()
-    print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-    return resp_sentences, resp_sentences_seg, output_logits, full_logits
+    print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
+    return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
 
 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
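The empty(3) → empty(4) change in receive_generate_info only works because torch.distributed.broadcast fills a pre-allocated tensor in place, so both sides of the exchange must agree on its length; the matching send_generate_info (not shown in this diff) has to pack a fourth value. A self-contained sketch of the pairing, with the fourth slot assumed to be the all_probs flag (an assumption, since the sender is not part of this diff):

    import torch
    import torch.distributed as dist

    # Illustrative only: assumes an initialized process group and mirrors the
    # shape contract, not the actual Megatron implementation.
    def send_generate_info(batch_size, seq_len, max_len, all_probs):
        info = torch.tensor([batch_size, seq_len, max_len, int(all_probs)],
                            dtype=torch.int64, device="cuda")
        dist.broadcast(info, 0)  # rank 0 is the source

    def receive_generate_info():
        info = torch.empty(4, dtype=torch.int64, device="cuda")  # length must match the sender's
        dist.broadcast(info, 0)  # filled in place from rank 0
        return (info[0].item(), info[1].item(),
                info[2].item(), bool(info[3].item()))

The same in-place contract explains the full_logits broadcast in synced_generate: every rank allocates the tensor with an identical shape before the collective runs, which is why args.padded_vocab_size (an attribute, not a callable; dropping the stray parentheses is the bug fix in that hunk) must evaluate the same on all ranks.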