megatron/text_generation_server.py  +16 −15

@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import torch
 import json
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron import mpu
 from megatron.text_generation_utils import generate

@@ -35,17 +36,20 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()),flush=True)
+        print("current time: ", datetime.datetime.now())
         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400

-        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
-        if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
-            if not isinstance(max_len, int):
-                return "max_len must be an integer greater than 0"
-            if max_len < 1:
-                return "max_len must be an integer greater than 0"
+        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        if "tokens_to_generate" in request.get_json():
+            tokens_to_generate = request.get_json()["tokens_to_generate"]
+            if not isinstance(tokens_to_generate, int):
+                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 1:
+                return "tokens_to_generate must be an integer greater than 0"

         all_probs = False
         if "all_probs" in request.get_json():

@@ -54,7 +58,7 @@ class MegatronGenerate(Resource):
             return "all_probs must be a boolean value"

         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, max_len, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                 "segments": resp_sentences_seg,

@@ -66,15 +70,12 @@ class MegatronGenerate(Resource):
             "segments": resp_sentences_seg,
             "logits": output_logits})

-def index():
-    return current_app.send_static_file('index.html')
-
 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__)
-        self.app.add_url_rule('/', 'index', index)
+        self.app = Flask(__name__, static_folder='static', static_url_path='')
+        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

     def run(self, url):
-        self.app.run(url, threaded=False, debug=False)
+        self.app.run(url, threaded=True, debug=False)
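With this change the /generate endpoint expects "tokens_to_generate" (a positive integer) in place of "max_len", alongside "sentences" and the optional "all_probs" flag. A minimal client call could look like the sketch below; the host and port are illustrative assumptions that depend on how the server was launched, and are not specified by this diff.

# Hypothetical client for the /generate endpoint; the address is an assumption.
import requests

response = requests.put(
    "http://localhost:5000/generate",       # assumed host:port, not from the diff
    json={
        "sentences": ["Megatron-LM is"],    # at most 128 sentences per request
        "tokens_to_generate": 32,           # positive int; server default is 64
        "all_probs": False,                 # True also returns full token logits
    },
)
print(response.json()["sentences"])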
megatron/text_generation_utils.py  +13 −15

@@ -104,12 +104,12 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor

-def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
     input_info_tensor = torch.cuda.LongTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)

@@ -125,7 +125,7 @@ def receive_generate_info():
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()
-    max_len = input_info_tensor[2].item()
+    tokens_to_generate = input_info_tensor[2].item()
     all_probs = input_info_tensor[3].item()

     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())

@@ -135,16 +135,16 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)

-    return context_length_tensor, context_tokens_tensor, max_len, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs

-def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
-                                                 max_len,
+                                                 tokens_to_generate,
                                                  all_probs)

     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1

@@ -175,15 +175,15 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

-def generate(model, sentences=None, max_len=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
-        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
-        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()

-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)

     if output is not None:
         decode_tokens, output_logits, full_logits = output
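The pattern in these hunks is that rank 0, which owns the HTTP request, packs the request metadata into a single integer tensor and broadcasts it, so the other ranks can allocate matching buffers before the context tensors themselves are broadcast. A stripped-down sketch of that handshake follows; it uses CPU tensors and the gloo backend so it can run without GPUs, which is an assumption for illustration only (the real code uses torch.cuda.LongTensor and also broadcasts the token and length tensors).

# Minimal sketch of the rank-0 metadata broadcast, reduced to two CPU processes.
# The worker function and two-process setup are illustrative, not Megatron code.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500",
        rank=rank, world_size=world_size,
    )
    if rank == 0:
        # Rank 0 knows the request: pack sizes and flags into one int tensor.
        # Layout mirrors the diff: [batch_size, seq_len, tokens_to_generate, all_probs]
        info = torch.tensor([2, 5, 64, 0], dtype=torch.int64)
    else:
        # Other ranks allocate an empty buffer of the agreed-upon shape.
        info = torch.empty(4, dtype=torch.int64)
    dist.broadcast(info, src=0)  # after this, every rank holds the same metadata
    batch_size, seq_len, tokens_to_generate, all_probs = (x.item() for x in info)
    print(f"rank {rank}: generate {tokens_to_generate} tokens for batch of {batch_size}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)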
@@ -264,7 +264,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,

 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          maxlen, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None):

     args = get_args()
     tokenizer = get_tokenizer()

@@ -280,7 +280,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     eos_id = tokenizer.eod

     counter = 0
-    org_context_length = context_length

     layer_past = None
     batch_size = context_tokens.size(0)

@@ -288,8 +287,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     tokens = context_tokens
     output_logits = None

-    # TODO(rprenger) maxlen should be named a different parameter
-    maxlen = maxlen + org_context_length
+    # Generate enough tokens for the longest sequence
+    maxlen = tokens_to_generate + context_lengths.max().item()

     # TODO(rprenger) Need a better understanding of what args.seq_length vs args.out_seq_length (shouldn't be "args")
     if maxlen > args.seq_length:

@@ -357,7 +356,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             if all_probs:
                 full_logits = torch.cat([full_logits, output_context], 1)

-            #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
             torch.distributed.broadcast(new_tokens, src, group)
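The new stopping bound in sample_sequence_batch is counted from the longest prompt in the batch rather than from a pre-added context offset. A toy illustration follows; seq_length stands in for args.seq_length, and the clamp inside the if-branch is an assumption about code beyond the truncated hunk, not something shown in this diff.

# Toy numbers only; seq_length is a stand-in for args.seq_length.
import torch

context_lengths = torch.tensor([3, 7, 5])        # per-sample prompt lengths
tokens_to_generate = 64
seq_length = 48                                  # model's maximum sequence length

maxlen = tokens_to_generate + context_lengths.max().item()   # 64 + 7 = 71
if maxlen > seq_length:
    maxlen = seq_length                          # assumed clamp to the context window
print(maxlen)                                    # 48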