Commit 8d405805 authored by rprenger's avatar rprenger
Browse files

Changing defaults and query sanitation to keep it from crashing on reasonable queries

parent 0dd5cc75
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation import generate_and_post_process


@@ -68,7 +67,7 @@ class MegatronGenerate(Resource):
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value"
        
        temperature = args.temperature
        temperature = 1.0
        if "temperature" in request.get_json():
            temperature = request.get_json()["temperature"]
            if not (type(temperature) == int or type(temperature) == float):
@@ -76,7 +75,7 @@ class MegatronGenerate(Resource):
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0"
        
        top_k = args.top_k
        top_k = 0.0
        if "top_k" in request.get_json():
            top_k = request.get_json()["top_k"]
            if not (type(top_k) == int):
@@ -84,11 +83,13 @@ class MegatronGenerate(Resource):
            if not (0 < top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
        
        top_p = args.top_p
        top_p = 0.0
        if "top_p" in request.get_json():
            top_p = request.get_json()["top_p"]
            if not (type(top_p) == float):
                return "top_p must be a positive float less than or equal to 1.0"
            if top_p > 0.0 and top_k > 0.0:
                return "cannot set both top-k and top-p samplings."
            if not (0 < top_p <= 1.0):
                return "top_p must be less than or equal to 1.0"