feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock, Thread
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.agent.entities import AgentToolEntity
@@ -15,15 +15,18 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -33,9 +36,9 @@ logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
_builtin_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels = {}
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
@@ -55,7 +58,7 @@ class ToolManager:
return cls._builtin_providers[provider]
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]:
"""
get the builtin tool
@@ -66,13 +69,15 @@ class ToolManager:
"""
provider_controller = cls.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod
def get_tool(
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
) -> Union[BuiltinTool, ApiTool]:
) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool
@@ -103,7 +108,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, ApiTool]:
) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool runtime
@@ -113,6 +118,7 @@ class ToolManager:
:return: the tool
"""
controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController]
if provider_type == "builtin":
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
@@ -129,7 +135,7 @@ class ToolManager:
)
# get credentials
builtin_provider: BuiltinToolProvider = (
builtin_provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@@ -177,7 +183,7 @@ class ToolManager:
}
)
elif provider_type == "workflow":
workflow_provider = (
workflow_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -187,8 +193,13 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools(
user_id="", tenant_id=workflow_provider.tenant_id
)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
return controller_tools[0].fork_tool_runtime(
runtime={
"tenant_id": tenant_id,
"credentials": {},
@@ -215,7 +226,7 @@ class ToolManager:
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = [x.value for x in parameter_rule.options]
options = [x.value for x in parameter_rule.options or []]
if parameter_value is not None and parameter_value not in options:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
@@ -267,6 +278,8 @@ class ToolManager:
identity_id=f"AGENT.{app_id}",
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -312,6 +325,9 @@ class ToolManager:
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -326,6 +342,8 @@ class ToolManager:
"""
# get provider
provider_controller = cls.get_builtin_provider(provider)
if provider_controller.identity is None:
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
absolute_path = path.join(
path.dirname(path.realpath(__file__)),
@@ -381,11 +399,15 @@ class ToolManager:
),
parent_type=BuiltinToolProviderController,
)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
provider_controller: BuiltinToolProviderController = provider_class()
if provider_controller.identity is None:
continue
cls._builtin_providers[provider_controller.identity.name] = provider_controller
for tool in provider_controller.get_tools() or []:
if tool.identity is None:
continue
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
yield provider
yield provider_controller
except Exception as e:
logger.exception(f"load builtin provider {provider}")
@@ -449,9 +471,11 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if provider.identity is None:
continue
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
@@ -472,7 +496,7 @@ class ToolManager:
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers = [
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
@@ -495,7 +519,7 @@ class ToolManager:
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers = []
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
@@ -505,7 +529,9 @@ class ToolManager:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
@@ -527,7 +553,7 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider = (
provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.id == provider_id,
@@ -556,7 +582,7 @@ class ToolManager:
get tool provider
"""
provider_name = provider
provider: ApiToolProvider = (
provider_tool: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@@ -565,17 +591,18 @@ class ToolManager:
.first()
)
if provider is None:
if provider_tool is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
credentials = json.loads(provider.credentials_str) or {}
credentials = json.loads(provider_tool.credentials_str) or {}
except:
credentials = {}
# package tool provider controller
controller = ApiToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
provider_tool,
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@@ -584,25 +611,28 @@ class ToolManager:
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try:
icon = json.loads(provider.icon)
icon = json.loads(provider_tool.icon)
except:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder(
{
"schema_type": provider.schema_type,
"schema": provider.schema,
"tools": provider.tools,
"icon": icon,
"description": provider.description,
"credentials": masked_credentials,
"privacy_policy": provider.privacy_policy,
"custom_disclaimer": provider.custom_disclaimer,
"labels": labels,
}
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_tool.schema_type,
"schema": provider_tool.schema,
"tools": provider_tool.tools,
"icon": icon,
"description": provider_tool.description,
"credentials": masked_credentials,
"privacy_policy": provider_tool.privacy_policy,
"custom_disclaimer": provider_tool.custom_disclaimer,
"labels": labels,
}
),
)
@classmethod
@@ -617,6 +647,7 @@ class ToolManager:
"""
provider_type = provider_type
provider_id = provider_id
provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
if provider_type == "builtin":
return (
dify_config.CONSOLE_API_URL
@@ -626,16 +657,21 @@ class ToolManager:
)
elif provider_type == "api":
try:
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
return json.loads(provider.icon)
if provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon = json.loads(provider.icon)
if isinstance(icon, (str, dict)):
return icon
return {"background": "#252525", "content": "\ud83d\ude01"}
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == "workflow":
provider: WorkflowToolProvider = (
provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -643,7 +679,13 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(provider.icon)
try:
icon = json.loads(provider.icon)
if isinstance(icon, (str, dict)):
return icon
return {"background": "#252525", "content": "\ud83d\ude01"}
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
else:
raise ValueError(f"provider type {provider_type} not found")