Improve: support custom model parameters in auto-generator (#22924)
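The generators previously clamped output with the PROMPT_GENERATION_MAX_TOKENS / CODE_GENERATION_MAX_TOKENS environment variables; after this change they forward whatever completion parameters the caller supplies in model_config["completion_params"]. A minimal sketch of the request payload this implies; the values are illustrative, and only "instruction", "no_variable", "model_config", and "completion_params" are names taken from the diff:

    # Hypothetical payload for the rule-generation endpoint; values are
    # illustrative, the key names come from the diff below.
    payload = {
        "instruction": "Generate a prompt for summarizing support tickets",
        "no_variable": True,
        "model_config": {
            "completion_params": {  # forwarded to the model as-is
                "max_tokens": 512,
                "temperature": 0.01,
            },
        },
    }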
@@ -1,5 +1,3 @@
-import os
-
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 
@@ -29,15 +27,12 @@ class RuleGenerateApi(Resource):
         args = parser.parse_args()
 
         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
-
         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 no_variable=args["no_variable"],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -64,14 +59,12 @@ class RuleCodeGenerateApi(Resource):
         args = parser.parse_args()
 
         account = current_user
-        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
         try:
             code_result = LLMGenerator.generate_code(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 code_language=args["code_language"],
-                max_tokens=CODE_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -125,16 +125,13 @@ class LLMGenerator:
         return questions
 
     @classmethod
-    def generate_rule_config(
-        cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512
-    ) -> dict:
+    def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
         output_parser = RuleConfigGeneratorOutputParser()
 
         error = ""
         error_step = ""
         rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
-        model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01}
-
+        model_parameters = model_config.get("completion_params", {})
         if no_variable:
             prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
 
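Both generators now resolve model parameters the same way. A standalone sketch of that lookup; the helper name is ours, and its one-line body mirrors the `+ model_parameters = ...` lines in the hunk above and in the generate_code hunk below:

    def resolve_model_parameters(model_config: dict) -> dict:
        # Caller-supplied completion_params win; with none given, the empty
        # dict leaves the provider's own defaults in effect, so there is no
        # server-side token cap any more.
        return model_config.get("completion_params", {})

    resolve_model_parameters({"completion_params": {"max_tokens": 1024}})
    # -> {"max_tokens": 1024}
    resolve_model_parameters({})
    # -> {}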
@@ -276,12 +273,7 @@ class LLMGenerator:
 
     @classmethod
     def generate_code(
-        cls,
-        tenant_id: str,
-        instruction: str,
-        model_config: dict,
-        code_language: str = "javascript",
-        max_tokens: int = 1000,
+        cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
     ) -> dict:
         if code_language == "python":
             prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
@@ -305,8 +297,7 @@ class LLMGenerator:
         )
 
         prompt_messages = [UserPromptMessage(content=prompt)]
-        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
-
+        model_parameters = model_config.get("completion_params", {})
         try:
             response = cast(
                 LLMResult,
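For reference, a call site under the new generate_code signature; the instruction text and parameter values are illustrative, and passing the removed max_tokens keyword would now raise a TypeError:

    code_result = LLMGenerator.generate_code(
        tenant_id=account.current_tenant_id,
        instruction="Write a function that reverses a string",  # illustrative
        model_config={"completion_params": {"max_tokens": 1024, "temperature": 0.01}},
        code_language="python",
    )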