feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.prompt.prompt_transform import AppMode
|
||||
from core.agent.agent_executor import PlanningStrategy
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
@@ -13,8 +15,8 @@ SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@staticmethod
|
||||
def is_dataset_exists(account: Account, dataset_id: str) -> bool:
|
||||
@classmethod
|
||||
def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool:
|
||||
# verify if the dataset ID exists
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
@@ -26,8 +28,8 @@ class AppModelConfigService:
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_model_completion_params(cp: dict, model_name: str) -> dict:
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict:
|
||||
# 6. model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
@@ -57,7 +59,7 @@ class AppModelConfigService:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
|
||||
if len(cp["stop"]) > 4:
|
||||
raise ValueError("stop sequences must be less than 4")
|
||||
|
||||
@@ -73,8 +75,8 @@ class AppModelConfigService:
|
||||
|
||||
return filtered_cp
|
||||
|
||||
@staticmethod
|
||||
def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
|
||||
# opening_statement
|
||||
if 'opening_statement' not in config or not config["opening_statement"]:
|
||||
config["opening_statement"] = ""
|
||||
@@ -153,33 +155,6 @@ class AppModelConfigService:
|
||||
if not isinstance(config["more_like_this"]["enabled"], bool):
|
||||
raise ValueError("enabled in more_like_this must be of boolean type")
|
||||
|
||||
# sensitive_word_avoidance
|
||||
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
|
||||
raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
|
||||
|
||||
if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
|
||||
config["sensitive_word_avoidance"]["words"] = ""
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["words"], str):
|
||||
raise ValueError("words in sensitive_word_avoidance must be of string type")
|
||||
|
||||
if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
|
||||
config["sensitive_word_avoidance"]["canned_response"] = ""
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
|
||||
raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
|
||||
|
||||
# model
|
||||
if 'model' not in config:
|
||||
raise ValueError("model is required")
|
||||
@@ -204,7 +179,7 @@ class AppModelConfigService:
|
||||
model_ids = [m['id'] for m in model_list]
|
||||
if config["model"]["name"] not in model_ids:
|
||||
raise ValueError("model.name must be in the specified model list")
|
||||
|
||||
|
||||
# model.mode
|
||||
if 'mode' not in config['model'] or not config['model']["mode"]:
|
||||
config['model']["mode"] = ""
|
||||
@@ -213,7 +188,7 @@ class AppModelConfigService:
|
||||
if 'completion_params' not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
config["model"]["completion_params"],
|
||||
config["model"]["name"]
|
||||
)
|
||||
@@ -330,14 +305,20 @@ class AppModelConfigService:
|
||||
except ValueError:
|
||||
raise ValueError("id in dataset must be of UUID type")
|
||||
|
||||
if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
|
||||
if not cls.is_dataset_exists(account, tool_item["id"]):
|
||||
raise ValueError("Dataset ID does not exist, please check your permission.")
|
||||
|
||||
|
||||
# dataset_query_variable
|
||||
AppModelConfigService.is_dataset_query_variable_valid(config, mode)
|
||||
cls.is_dataset_query_variable_valid(config, mode)
|
||||
|
||||
# advanced prompt validation
|
||||
AppModelConfigService.is_advanced_prompt_valid(config, mode)
|
||||
cls.is_advanced_prompt_valid(config, mode)
|
||||
|
||||
# external data tools validation
|
||||
cls.is_external_data_tools_valid(tenant_id, config)
|
||||
|
||||
# moderation validation
|
||||
cls.is_moderation_valid(tenant_id, config)
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {
|
||||
@@ -348,6 +329,7 @@ class AppModelConfigService:
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
"model": {
|
||||
"provider": config["model"]["provider"],
|
||||
"name": config["model"]["name"],
|
||||
@@ -365,32 +347,86 @@ class AppModelConfigService:
|
||||
}
|
||||
|
||||
return filtered_config
|
||||
|
||||
@staticmethod
|
||||
def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
|
||||
|
||||
@classmethod
|
||||
def is_moderation_valid(cls, tenant_id: str, config: dict):
|
||||
if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
|
||||
if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
|
||||
config["sensitive_word_avoidance"]["enabled"] = False
|
||||
|
||||
if not config["sensitive_word_avoidance"]["enabled"]:
|
||||
return
|
||||
|
||||
if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]:
|
||||
raise ValueError("sensitive_word_avoidance.type is required")
|
||||
|
||||
type = config["sensitive_word_avoidance"]["type"]
|
||||
config = config["sensitive_word_avoidance"]["config"]
|
||||
|
||||
ModerationFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
|
||||
if 'external_data_tools' not in config or not config["external_data_tools"]:
|
||||
config["external_data_tools"] = []
|
||||
|
||||
if not isinstance(config["external_data_tools"], list):
|
||||
raise ValueError("external_data_tools must be of list type")
|
||||
|
||||
for tool in config["external_data_tools"]:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool["enabled"] = False
|
||||
|
||||
if not tool["enabled"]:
|
||||
continue
|
||||
|
||||
if "type" not in tool or not tool["type"]:
|
||||
raise ValueError("external_data_tools[].type is required")
|
||||
|
||||
type = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(
|
||||
name=type,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
|
||||
# Only check when mode is completion
|
||||
if mode != 'completion':
|
||||
return
|
||||
|
||||
|
||||
agent_mode = config.get("agent_mode", {})
|
||||
tools = agent_mode.get("tools", [])
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
|
||||
|
||||
dataset_query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_exists and not dataset_query_variable:
|
||||
raise ValueError("Dataset query variable is required when dataset is exist")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
|
||||
@classmethod
|
||||
def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
|
||||
# prompt_type
|
||||
if 'prompt_type' not in config or not config["prompt_type"]:
|
||||
config["prompt_type"] = "simple"
|
||||
|
||||
if config['prompt_type'] not in ['simple', 'advanced']:
|
||||
raise ValueError("prompt_type must be in ['simple', 'advanced']")
|
||||
|
||||
|
||||
# chat_prompt_config
|
||||
if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
|
||||
config["chat_prompt_config"] = {}
|
||||
@@ -404,7 +440,7 @@ class AppModelConfigService:
|
||||
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
@@ -415,10 +451,10 @@ class AppModelConfigService:
|
||||
if config['prompt_type'] == 'advanced':
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
|
||||
|
||||
|
||||
if config['model']["mode"] not in ['chat', 'completion']:
|
||||
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
|
||||
|
||||
|
||||
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
@@ -429,9 +465,8 @@ class AppModelConfigService:
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
|
Reference in New Issue
Block a user