megatron/text_generation_server.py  (+45 −20)

@@ -40,9 +40,18 @@ class MegatronGenerate(Resource):
         print(json.dumps(request.get_json()),flush=True)
         print("current time: ", datetime.datetime.now())
-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+
+        if "max_len" in request.get_json():
+            return "max_len is no longer used. Replace with tokens_to_generate", 400
+
+        if "sentences" in request.get_json():
+            return "sentences is no longer used. Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400

         tokens_to_generate = 64  # Choosing hopefully sane default. Full sequence is slow
         if "tokens_to_generate" in request.get_json():

@@ -52,11 +61,11 @@ class MegatronGenerate(Resource):
             if tokens_to_generate < 1:
                 return "tokens_to_generate must be an integer greater than 0"

-        all_probs = False
-        if "all_probs" in request.get_json():
-            all_probs = request.get_json()["all_probs"]
-            if not isinstance(all_probs, bool):
-                return "all_probs must be a boolean value"
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"

         temperature = args.temperature
         if "temperature" in request.get_json():

@@ -66,6 +75,22 @@ class MegatronGenerate(Resource):
             if not (0.0 < temperature <= 100.0):
                 return "temperature must be a positive number less than or equal to 100.0"

+        top_k = args.top_k
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 < top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+
+        top_p = args.top_p
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if not (0 < top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]

@@ -74,24 +99,24 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
+            response, response_seg, response_logprobs = generate(self.model, prompts, tokens_to_generate, logprobs, temperature, top_k, top_p, add_BOS)

-        if all_probs:
-            return jsonify({"sentences": resp_sentences, "segments": resp_sentences_seg, "logits": output_logits, "all_logits": full_logits, "tokens": tokens})
-
-        return jsonify({"sentences": resp_sentences, "segments": resp_sentences_seg, "logits": output_logits})
+        return jsonify({"text": response, "segments": response_seg, "logprobs": response_logprobs})

 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])

     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
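With this change the resource is registered at /api and the request body uses prompts, tokens_to_generate, logprobs, top_k and top_p; the old sentences, max_len and all_probs fields are rejected with a 400. A minimal client call might look like the sketch below; the host, port, and parameter values are assumptions, not part of the patch.

    # Hypothetical client for the updated /api endpoint; URL and values are assumptions.
    import json
    import requests

    payload = {
        "prompts": ["The quick brown fox"],  # replaces the old "sentences" field
        "tokens_to_generate": 32,            # replaces the old "max_len" field
        "logprobs": True,                    # replaces the old "all_probs" flag
        "temperature": 0.9,
        "top_k": 40,                         # integer in (0, 1000] if supplied
    }
    resp = requests.put("http://localhost:5000/api",
                        data=json.dumps(payload),
                        headers={"Content-Type": "application/json"})
    result = resp.json()
    print(result["text"][0])       # generated text (was "sentences" in the old API)
    print(result["logprobs"][0])   # per-token log-probabilities, present when requested

The CLI tool at the end of this diff makes the same PUT request with urllib2; the sketch above only restates it with the requests library for readability.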
megatron/text_generation_utils.py  (+68 −68)

@@ -108,12 +108,12 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor

-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)

@@ -125,13 +125,15 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
     tokens_to_generate = int(input_info_tensor[2].item())
-    all_probs = int(input_info_tensor[3].item())
+    logprobs = bool(input_info_tensor[3].item())
     temperature = float(input_info_tensor[4].item())
+    top_k = int(input_info_tensor[5].item())
+    top_p = float(input_info_tensor[6].item())

     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())

@@ -140,28 +142,31 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)

-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p

-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs,
-                                                 temperature=temperature)
+                                                 logprobs, temperature, top_k, top_p)

-    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1

+    if logprobs:
         if mpu.is_pipeline_last_stage():
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
             torch.distributed.broadcast(output_logits, src, group)
-            if all_probs:
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                torch.distributed.broadcast(full_logits, src, group)
         else:
             if mpu.is_pipeline_first_stage():

@@ -170,26 +175,20 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
                 output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
                 torch.distributed.broadcast(output_logits, src, group)
-                if all_probs:
-                    args = get_args()
-                    src = mpu.get_pipeline_model_parallel_last_rank()
-                    group = mpu.get_embedding_group()
-                    full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
-                    torch.distributed.broadcast(full_logits, src, group)

     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits
+        return tokens[:, :context_length], output_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()

-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     if output is not None:
-        decode_tokens, output_logits, full_logits = output
+        decode_tokens, output_logits = output

         args = get_args()
         tokenizer = get_tokenizer()

@@ -197,7 +196,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
         resp_sentences_seg = []

         decode_tokens = decode_tokens.cpu().numpy().tolist()
-        for decode_token in decode_tokens:
+        for i, decode_token in enumerate(decode_tokens):
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:

@@ -206,11 +206,9 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
                 words.append(word)
             resp_sentences_seg.append(words)

+        if logprobs:
             output_logits = output_logits.cpu().numpy().tolist()
-        if all_probs:
-            full_logits = full_logits.cpu().numpy().tolist()

-        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
+        return resp_sentences, resp_sentences_seg, output_logits

 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """

@@ -260,9 +258,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     return output_tensor

-def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
+def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokens_to_generate, logprobs, temperature, top_k, top_p, type_ids=None):

     args = get_args()
     tokenizer = get_tokenizer()
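The remaining hunks of sample_sequence_batch, shown below, route the per-request top_k and top_p values into top_k_logits (instead of the global args) and gate the log-probability bookkeeping on the logprobs flag. As a point of reference for that filtering step, here is a minimal combined top-k / nucleus (top-p) filter; it is illustrative only, is not the project's top_k_logits implementation, and the name filter_logits is made up.

    # Illustrative top-k / top-p (nucleus) filtering sketch, not Megatron's top_k_logits.
    import torch
    import torch.nn.functional as F

    def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
        # logits: [batch, vocab]; top_k == 0 and top_p == 0.0 leave logits untouched.
        if top_k > 0:
            # Mask everything below the k-th largest logit in each row.
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, filter_value)
        if top_p > 0.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens once cumulative probability exceeds top_p,
            # always keeping the highest-probability token.
            remove = cum_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            mask = remove.scatter(-1, sorted_idx, remove)
            logits = logits.masked_fill(mask, filter_value)
        return logits

Filtered logits would then feed F.softmax and torch.multinomial exactly as in the sampling hunk that follows.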
@@ -330,8 +336,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 else:
                     logits = logits.float()
                     logits /= temperature
-                    logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+                    logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
                 log_probs = F.softmax(logits, dim=-1)
                 prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length

@@ -344,12 +350,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens

+                if logprobs:
                     if output_logits is None:
                         output_context = F.log_softmax(output[:, :context_length, :], 2)
                         indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                         output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-                        if all_probs:
-                            full_logits = output_context
                     else:
                         output_context = F.log_softmax(output, 2)
                         indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)

@@ -357,8 +362,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                         # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
                         output_logits = torch.cat([output_logits, new_output_logits],1)
-                        if all_probs:
-                            full_logits = torch.cat([full_logits, output_context], 1)

                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()

@@ -373,10 +376,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                if all_probs:
-                    yield tokens, lengths, output_logits, full_logits
-                else:
-                    yield tokens, lengths, output_logits, None
+                yield tokens, lengths, output_logits

             else:
                 if mpu.is_pipeline_first_stage():

@@ -385,9 +385,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None, None, None
+                    yield tokens, None, None
                 else:
-                    yield None, None, None, None
+                    yield None, None, None

             done = torch.cuda.ByteTensor([0])
             src = mpu.get_pipeline_model_parallel_last_rank()


tools/text_generation_cli.py  (+3 −3)

@@ -25,10 +25,10 @@ if __name__ == "__main__":
     url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
-        max_len = int(input("Enter number tokens output: "))
-        data = json.dumps({"sentences": [sentence], "max_len":max_len})
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = json.dumps({"prompts": [sentence], "tokens_to_generate":tokens_to_generate})
         req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
-        print(resp_sentences["sentences"][0])
+        print(resp_sentences["text"][0])
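A final note on megatron/text_generation_utils.py: send_generate_info and receive_generate_info must agree exactly on the number, order, and types of the packed scalars, which is why the receive side grows from 5 to 7 float slots when top_k and top_p are added. Reduced to its core, the pattern looks roughly like the sketch below; the function names and the 5-slot layout are illustrative, and it assumes torch.distributed and CUDA are already initialized.

    # Sketch of the "pack scalars into one tensor and broadcast" pattern used by
    # send_generate_info / receive_generate_info. Names and layout are illustrative.
    import torch
    import torch.distributed as dist

    def send_info(tokens_to_generate, logprobs, temperature, top_k, top_p):
        # Rank 0 packs every scalar, bools and ints included, into one float tensor.
        info = torch.cuda.FloatTensor(
            [tokens_to_generate, logprobs, temperature, top_k, top_p])
        dist.broadcast(info, 0)

    def receive_info():
        # Other ranks must allocate the same number of slots (here 5) and unpack
        # them in the same order, casting back to the original types.
        info = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
        dist.broadcast(info, 0)
        tokens_to_generate = int(info[0].item())
        logprobs = bool(info[1].item())
        temperature = float(info[2].item())
        top_k = int(info[3].item())
        top_p = float(info[4].item())
        return tokens_to_generate, logprobs, temperature, top_k, top_p

If the two sides ever disagree on the slot count or ordering, every non-zero rank silently unpacks garbage, so the pair of functions has to be changed together, as this patch does.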