Commit f47aa770 authored by Jared Casper

Merge branch 'add-temperature-parameter-to-server-api' into 'main'

Add temperature to the server API

See merge request ADLR/megatron-lm!325
parents a97d676b 527e07c0
+8 −1
@@ -55,8 +55,15 @@ class MegatronGenerate(Resource):
             if not isinstance(all_probs, bool):
                 return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+               0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
         if all_probs:
             return jsonify({"sentences": resp_sentences,
                 "segments": resp_sentences_seg,
+8 −8
@@ -138,14 +138,15 @@ def receive_generate_info():
 
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs)
+                                                 all_probs,
+                                                 temperature=temperature)
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
 
@@ -174,7 +175,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
@@ -182,8 +183,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
-    
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
 
@@ -262,7 +262,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
     args = get_args()
     tokenizer = get_tokenizer()
 
@@ -324,7 +324,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:
                     logits = logits.float()
-                    logits /= args.temperature
+                    logits /= temperature
                     logits = top_k_logits(logits, top_k=args.top_k,
                                           top_p=args.top_p)
                     log_probs = F.softmax(logits, dim=-1)
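
The final hunk is where the parameter takes effect: logits are divided by the request's temperature before top-k/top-p filtering, so a temperature below 1.0 sharpens the softmax distribution toward the argmax and a temperature above 1.0 flattens it toward uniform. A self-contained illustration of that scaling (standalone PyTorch, not code from this repository):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 1.0, 0.1])
    for temperature in (0.5, 1.0, 2.0):
        # Dividing logits by temperature rescales them before softmax:
        # lower temperature concentrates probability mass on the largest
        # logit; higher temperature spreads it out.
        probs = F.softmax(logits / temperature, dim=-1)
        print(temperature, probs.tolist())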