feat: mypy for all type check (#10921)
This commit is contained in:
@@ -4,7 +4,7 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
@@ -15,15 +15,18 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -33,9 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
_builtin_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels = {}
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
@@ -55,7 +58,7 @@ class ToolManager:
|
||||
return cls._builtin_providers[provider]
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
@@ -66,13 +69,15 @@ class ToolManager:
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
|
||||
@classmethod
|
||||
def get_tool(
|
||||
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
|
||||
) -> Union[BuiltinTool, ApiTool]:
|
||||
) -> Union[BuiltinTool, ApiTool, Tool]:
|
||||
"""
|
||||
get the tool
|
||||
|
||||
@@ -103,7 +108,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
) -> Union[BuiltinTool, ApiTool]:
|
||||
) -> Union[BuiltinTool, ApiTool, Tool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@@ -113,6 +118,7 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController]
|
||||
if provider_type == "builtin":
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
|
||||
@@ -129,7 +135,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = (
|
||||
builtin_provider: Optional[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
@@ -177,7 +183,7 @@ class ToolManager:
|
||||
}
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
workflow_provider = (
|
||||
workflow_provider: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
@@ -187,8 +193,13 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: Optional[list[Tool]] = controller.get_tools(
|
||||
user_id="", tenant_id=workflow_provider.tenant_id
|
||||
)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
return controller_tools[0].fork_tool_runtime(
|
||||
runtime={
|
||||
"tenant_id": tenant_id,
|
||||
"credentials": {},
|
||||
@@ -215,7 +226,7 @@ class ToolManager:
|
||||
|
||||
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in parameter_rule.options]
|
||||
options = [x.value for x in parameter_rule.options or []]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||
@@ -267,6 +278,8 @@ class ToolManager:
|
||||
identity_id=f"AGENT.{app_id}",
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
@@ -312,6 +325,9 @@ class ToolManager:
|
||||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@@ -326,6 +342,8 @@ class ToolManager:
|
||||
"""
|
||||
# get provider
|
||||
provider_controller = cls.get_builtin_provider(provider)
|
||||
if provider_controller.identity is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
|
||||
|
||||
absolute_path = path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
@@ -381,11 +399,15 @@ class ToolManager:
|
||||
),
|
||||
parent_type=BuiltinToolProviderController,
|
||||
)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
for tool in provider.get_tools():
|
||||
provider_controller: BuiltinToolProviderController = provider_class()
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
cls._builtin_providers[provider_controller.identity.name] = provider_controller
|
||||
for tool in provider_controller.get_tools() or []:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||
yield provider
|
||||
yield provider_controller
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"load builtin provider {provider}")
|
||||
@@ -449,9 +471,11 @@ class ToolManager:
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if provider.identity is None:
|
||||
continue
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
@@ -472,7 +496,7 @@ class ToolManager:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers = [
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
@@ -495,7 +519,7 @@ class ToolManager:
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers = []
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
@@ -505,7 +529,9 @@ class ToolManager:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
@@ -527,7 +553,7 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider = (
|
||||
provider: Optional[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.id == provider_id,
|
||||
@@ -556,7 +582,7 @@ class ToolManager:
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider: ApiToolProvider = (
|
||||
provider_tool: Optional[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@@ -565,17 +591,18 @@ class ToolManager:
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
if provider_tool is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider.credentials_str) or {}
|
||||
credentials = json.loads(provider_tool.credentials_str) or {}
|
||||
except:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
|
||||
provider_tool,
|
||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
@@ -584,25 +611,28 @@ class ToolManager:
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
||||
try:
|
||||
icon = json.loads(provider.icon)
|
||||
icon = json.loads(provider_tool.icon)
|
||||
except:
|
||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider.schema_type,
|
||||
"schema": provider.schema,
|
||||
"tools": provider.tools,
|
||||
"icon": icon,
|
||||
"description": provider.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider.privacy_policy,
|
||||
"custom_disclaimer": provider.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
return cast(
|
||||
dict,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_tool.schema_type,
|
||||
"schema": provider_tool.schema,
|
||||
"tools": provider_tool.tools,
|
||||
"icon": icon,
|
||||
"description": provider_tool.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_tool.privacy_policy,
|
||||
"custom_disclaimer": provider_tool.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -617,6 +647,7 @@ class ToolManager:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
|
||||
if provider_type == "builtin":
|
||||
return (
|
||||
dify_config.CONSOLE_API_URL
|
||||
@@ -626,16 +657,21 @@ class ToolManager:
|
||||
)
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider = (
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
return json.loads(provider.icon)
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
icon = json.loads(provider.icon)
|
||||
if isinstance(icon, (str, dict)):
|
||||
return icon
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
elif provider_type == "workflow":
|
||||
provider: WorkflowToolProvider = (
|
||||
provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
@@ -643,7 +679,13 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return json.loads(provider.icon)
|
||||
try:
|
||||
icon = json.loads(provider.icon)
|
||||
if isinstance(icon, (str, dict)):
|
||||
return icon
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
Reference in New Issue
Block a user