megatron/text_generation_server.py (+17 −6)

@@ -15,6 +15,7 @@
 import datetime
 import torch
 import json
+import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args

@@ -22,6 +23,7 @@ from megatron import mpu
 from megatron.text_generation_utils import generate

 GENERATE_NUM = 0
+sem = threading.Semaphore()

 class MegatronGenerate(Resource):
     def __init__(self, model):

@@ -37,6 +39,7 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()),flush=True)
+        print("current time: ", datetime.datetime.now())

         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400

@@ -55,8 +58,17 @@ class MegatronGenerate(Resource):
             if not isinstance(all_probs, bool):
                 return "all_probs must be a boolean value"

+        add_BOS = False
+        if "add_BOS" in request.get_json():
+            add_BOS = request.get_json()["add_BOS"]
+            if not isinstance(add_BOS, bool):
+                return "add_BOS must be a boolean value"
+
+        sem.acquire()  # Need to get lock to keep multiple threads from hitting code
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS)
+        sem.release()

         if all_probs:
             return jsonify({"sentences": resp_sentences, "segments": resp_sentences_seg,

@@ -70,8 +82,7 @@ class MegatronGenerate(Resource):

 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__, static_folder='static', static_url_path='')
-        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
+        self.app = Flask(__name__)
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
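Taken together, these server changes (1) serialize concurrent HTTP requests through a module-level threading.Semaphore() so only one generate call reaches the model and the other ranks at a time, and (2) expose an optional add_BOS flag in the request body. A minimal client sketch against the updated /generate endpoint follows; the host, port, and HTTP verb are assumptions (the resource's handler method name is not visible in this diff), while the field names and the 128-sentence limit come from the code above:

import requests

# Hypothetical address: depends on how MegatronServer is actually run.
URL = "http://localhost:5000/generate"

payload = {
    "sentences": ["Deep learning is"],  # server rejects more than 128 sentences
    "tokens_to_generate": 32,
    "all_probs": False,
    "add_BOS": True,  # new optional flag; must be a boolean or the request is rejected
}

# Assuming the resource handles PUT; swap for requests.post if the handler is post().
resp = requests.put(URL, json=payload)
print(resp.json()["sentences"])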
megatron/text_generation_utils.py (+7 −4)

@@ -95,9 +95,12 @@ def pad_batch(batch, pad_id, max_len):
         context_lengths.append(context_length)
     return batch, context_lengths

-def tokenize_batch(sentences, max_len):
+def tokenize_batch(sentences, max_len, add_BOS):
     args = get_args()
     tokenizer = get_tokenizer()
-    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+    if add_BOS:
+        context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
+    else:
+        context_tokens = [tokenizer.tokenize(s) for s in sentences]
     context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, max_len)

@@ -174,10 +177,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
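The effect of the new add_BOS argument to tokenize_batch is easiest to see in isolation. Below is a minimal sketch, assuming a stand-in tokenizer (ToyTokenizer and its token ids are invented for illustration; as the diff shows, Megatron reuses tokenizer.eod as both the BOS and the pad id):

class ToyTokenizer:
    """Stand-in for Megatron's tokenizer; eod doubles as BOS/pad id, as in pad_batch."""
    eod = 0

    def tokenize(self, s):
        # Invented mapping: one fake id per word, offset so 0 stays reserved for eod.
        return [hash(w) % 100 + 1 for w in s.split()]

def tokenize_batch_sketch(sentences, add_BOS, tokenizer=ToyTokenizer()):
    # Mirrors the diff: optionally prepend tokenizer.eod as a BOS token.
    if add_BOS:
        return [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
    return [tokenizer.tokenize(s) for s in sentences]

print(tokenize_batch_sketch(["hello world"], add_BOS=True))   # e.g. [[0, 42, 7]]
print(tokenize_batch_sketch(["hello world"], add_BOS=False))  # e.g. [[42, 7]]

Because generate simply threads add_BOS through to tokenize_batch on rank 0 (the other ranks receive the already-tokenized tensors via send_generate_info/receive_generate_info), the flag changes only how prompts are tokenized, not the generation loop itself.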