feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast
from httpx import get
@@ -28,12 +29,12 @@ logger = logging.getLogger(__name__)
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
def parser_api_schema(schema: str) -> Mapping[str, Any]:
"""
parse api schema to tool bundle
"""
try:
warnings = {}
warnings: dict[str, str] = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
@@ -68,13 +69,16 @@ class ApiToolManageService:
),
]
return jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
return cast(
Mapping,
jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
),
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@@ -129,7 +133,7 @@ class ApiToolManageService:
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -262,9 +266,8 @@ class ApiToolManageService:
if provider is None:
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -416,7 +419,7 @@ class ApiToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(
runtime_tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
@@ -454,7 +457,7 @@ class ApiToolManageService:
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
for tool in tools:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels

View File

@@ -50,8 +50,8 @@ class BuiltinToolManageService:
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
result: list[UserTool] = []
for tool in tools or []:
result.append(
ToolTransformService.tool_to_user_tool(
tool=tool,
@@ -217,6 +217,8 @@ class BuiltinToolManageService:
name_func=lambda x: x.identity.name,
):
continue
if provider_controller.identity is None:
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
@@ -229,7 +231,7 @@ class BuiltinToolManageService:
ToolTransformService.repack_provider(user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools:
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Optional, Union, cast
from configs import dify_config
from core.tools.entities.api_entities import UserTool, UserToolProvider
@@ -35,7 +35,7 @@ class ToolTransformService:
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
return json.loads(icon)
return cast(dict, json.loads(icon))
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@@ -53,8 +53,11 @@ class ToolTransformService:
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
provider.icon = cast(
str,
ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
),
)
@staticmethod
@@ -66,6 +69,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")
result = UserToolProvider(
id=provider_controller.identity.name,
author=provider_controller.identity.author,
@@ -93,7 +99,8 @@ class ToolTransformService:
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
assert result.masked_credentials is not None, "masked credentials is None"
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
# check if the provider need credentials
if not provider_controller.need_credentials:
@@ -149,6 +156,9 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
if provider_controller.identity is None:
raise ValueError("provider identity is None")
return UserToolProvider(
id=provider_controller.provider_id,
author=provider_controller.identity.author,
@@ -180,6 +190,8 @@ class ToolTransformService:
convert provider controller to user provider
"""
username = "Anonymous"
if db_provider.user is None:
raise ValueError(f"user is None for api provider {db_provider.id}")
try:
username = db_provider.user.name
except Exception as e:
@@ -256,19 +268,25 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
if tool.identity is None:
raise ValueError("tool identity is None")
return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
description=I18nObject(
en_US=tool.description.human if tool.description else "",
zh_Hans=tool.description.human if tool.description else "",
),
parameters=current_parameters,
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
name=tool.operation_id or "",
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels,

View File

@@ -6,8 +6,10 @@ from typing import Any, Optional
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -32,7 +34,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: Mapping[str, Any],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -97,7 +99,7 @@ class WorkflowToolManageService:
label: str,
icon: dict,
description: str,
parameters: list[dict],
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -131,7 +133,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider = (
workflow_tool_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -140,14 +142,14 @@ class WorkflowToolManageService:
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App = (
app: Optional[App] = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
workflow: Workflow = app.workflow
workflow: Optional[Workflow] = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
@@ -193,7 +195,7 @@ class WorkflowToolManageService:
# skip deleted tools
pass
labels = ToolLabelManager.get_tools_labels(tools)
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
result = []
@@ -202,10 +204,11 @@ class WorkflowToolManageService:
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
continue
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
)
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
]
result.append(user_tool_provider)
@@ -236,7 +239,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -245,13 +248,19 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return {
"name": db_tool.name,
"label": db_tool.label,
@@ -261,9 +270,9 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@@ -276,7 +285,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
@@ -285,12 +294,17 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found")
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
workflow_app: Optional[App] = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
)
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_app_id} not found")
return {
"name": db_tool.name,
@@ -301,14 +315,14 @@ class WorkflowToolManageService:
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
"synced": workflow_app.workflow.version == db_tool.version,
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
"""
List workflow tool provider tools.
:param user_id: the user id
@@ -316,7 +330,7 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = (
db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -326,9 +340,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
)
]
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]