megatron/text_generation_server.py  +8 −1

@@ -55,8 +55,15 @@ class MegatronGenerate(Resource):
         if not isinstance(all_probs, bool):
             return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+               0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                 "segments": resp_sentences_seg,
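A client opts into the new parameter by adding a "temperature" field to the JSON body it already sends alongside "sentences", "tokens_to_generate", and "all_probs". The sketch below is a hypothetical client call: the URL, port, and HTTP verb are assumptions about how the MegatronGenerate resource is mounted, not something this diff specifies. Note that the validation uses isinstance(temperature, float), so an integer such as 1 would be rejected; clients should send a float.

# Hypothetical client for the text generation server; the endpoint URL, port,
# and HTTP verb are assumptions, not taken from this diff.
import requests

payload = {
    "sentences": ["The quick brown fox"],   # prompts to complete
    "tokens_to_generate": 32,               # number of new tokens to sample
    "all_probs": False,                     # skip the full-logits payload
    "temperature": 0.7,                     # must be a float in (0.0, 100.0]
}
resp = requests.put("http://localhost:5000/generate", json=payload)  # assumed mount point
print(resp.json()["sentences"])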
megatron/text_generation_utils.py  +8 −8

@@ -138,14 +138,15 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor,
                                                  attention_mask, position_ids,
-                                                 tokens_to_generate, all_probs)
+                                                 tokens_to_generate,
+                                                 all_probs, temperature=temperature)
 
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1

@@ -174,7 +175,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)

@@ -182,8 +183,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor,
-                             tokens_to_generate, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
 
     if output is not None:
         decode_tokens, output_logits, full_logits = output

@@ -262,7 +262,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
     args = get_args()
     tokenizer = get_tokenizer()

@@ -324,7 +324,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
                 logits = logits.float()
-                logits /= args.temperature
+                logits /= temperature
                 logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
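The functional change in text_generation_utils.py is that the logits are divided by the per-request temperature, rather than the server-wide args.temperature, before the softmax and the top-k/top-p filtering. Below is a minimal, self-contained sketch of that scaling step only; the top-k/top-p filtering is omitted and a plain multinomial draw stands in for Megatron's sampling loop, so this is an illustration of the technique, not the repository's code.

# Minimal sketch of temperature-scaled sampling (not Megatron's actual loop):
# temperature < 1.0 sharpens the distribution, temperature > 1.0 flattens it,
# and 1.0 leaves the softmax unchanged.
import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Draw one token id per row from temperature-scaled logits of shape [batch, vocab]."""
    scaled = logits.float() / temperature
    probs = F.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

logits = torch.randn(2, 50257)                      # fake [batch, vocab] logits
print(sample_next_token(logits, temperature=0.7))   # two sampled token ids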