Loading megatron/text_generation_server.py +5 −4 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ import threading from flask import Flask, request, jsonify, current_app from flask_restful import Resource, Api from megatron import get_args from megatron import mpu from megatron.text_generation import generate_and_post_process Loading Loading @@ -68,7 +67,7 @@ class MegatronGenerate(Resource): if not isinstance(logprobs, bool): return "logprobs must be a boolean value" temperature = args.temperature temperature = 1.0 if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not (type(temperature) == int or type(temperature) == float): Loading @@ -76,7 +75,7 @@ class MegatronGenerate(Resource): if not (0.0 < temperature <= 100.0): return "temperature must be a positive number less than or equal to 100.0" top_k = args.top_k top_k = 0.0 if "top_k" in request.get_json(): top_k = request.get_json()["top_k"] if not (type(top_k) == int): Loading @@ -84,11 +83,13 @@ class MegatronGenerate(Resource): if not (0 < top_k <= 1000): return "top_k must be equal to or greater than 0 and less than or equal to 1000" top_p = args.top_p top_p = 0.0 if "top_p" in request.get_json(): top_p = request.get_json()["top_p"] if not (type(top_p) == float): return "top_p must be a positive float less than or equal to 1.0" if top_p > 0.0 and top_k > 0.0: return "cannot set both top-k and top-p samplings." if not (0 < top_p <= 1.0): return "top_p must be less than or equal to 1.0" Loading Loading
megatron/text_generation_server.py +5 −4 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ import threading from flask import Flask, request, jsonify, current_app from flask_restful import Resource, Api from megatron import get_args from megatron import mpu from megatron.text_generation import generate_and_post_process Loading Loading @@ -68,7 +67,7 @@ class MegatronGenerate(Resource): if not isinstance(logprobs, bool): return "logprobs must be a boolean value" temperature = args.temperature temperature = 1.0 if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not (type(temperature) == int or type(temperature) == float): Loading @@ -76,7 +75,7 @@ class MegatronGenerate(Resource): if not (0.0 < temperature <= 100.0): return "temperature must be a positive number less than or equal to 100.0" top_k = args.top_k top_k = 0.0 if "top_k" in request.get_json(): top_k = request.get_json()["top_k"] if not (type(top_k) == int): Loading @@ -84,11 +83,13 @@ class MegatronGenerate(Resource): if not (0 < top_k <= 1000): return "top_k must be equal to or greater than 0 and less than or equal to 1000" top_p = args.top_p top_p = 0.0 if "top_p" in request.get_json(): top_p = request.get_json()["top_p"] if not (type(top_p) == float): return "top_p must be a positive float less than or equal to 1.0" if top_p > 0.0 and top_k > 0.0: return "cannot set both top-k and top-p samplings." if not (0 < top_p <= 1.0): return "top_p must be less than or equal to 1.0" Loading