feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from httpx import get
|
||||
|
||||
@@ -28,12 +29,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
def parser_api_schema(schema: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
warnings: dict[str, str] = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
@@ -68,13 +69,16 @@ class ApiToolManageService:
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
return cast(
|
||||
Mapping,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
@@ -129,7 +133,7 @@ class ApiToolManageService:
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
@@ -262,9 +266,8 @@ class ApiToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
extra_info: dict[str, str] = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
@@ -416,7 +419,7 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime_tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
@@ -454,7 +457,7 @@ class ApiToolManageService:
|
||||
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
|
@@ -50,8 +50,8 @@ class BuiltinToolManageService:
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
result: list[UserTool] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
@@ -217,6 +217,8 @@ class BuiltinToolManageService:
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
if provider_controller.identity is None:
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
@@ -229,7 +231,7 @@ class BuiltinToolManageService:
|
||||
ToolTransformService.repack_provider(user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
@@ -35,7 +35,7 @@ class ToolTransformService:
|
||||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
return cast(dict, json.loads(icon))
|
||||
except:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@@ -53,8 +53,11 @@ class ToolTransformService:
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
provider.icon = cast(
|
||||
str,
|
||||
ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -66,6 +69,9 @@ class ToolTransformService:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
result = UserToolProvider(
|
||||
id=provider_controller.identity.name,
|
||||
author=provider_controller.identity.author,
|
||||
@@ -93,7 +99,8 @@ class ToolTransformService:
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
assert result.masked_credentials is not None, "masked credentials is None"
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@@ -149,6 +156,9 @@ class ToolTransformService:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
if provider_controller.identity is None:
|
||||
raise ValueError("provider identity is None")
|
||||
|
||||
return UserToolProvider(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.identity.author,
|
||||
@@ -180,6 +190,8 @@ class ToolTransformService:
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = "Anonymous"
|
||||
if db_provider.user is None:
|
||||
raise ValueError(f"user is None for api provider {db_provider.id}")
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
@@ -256,19 +268,25 @@ class ToolTransformService:
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is None")
|
||||
|
||||
return UserTool(
|
||||
author=tool.identity.author,
|
||||
name=tool.identity.name,
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
description=I18nObject(
|
||||
en_US=tool.description.human if tool.description else "",
|
||||
zh_Hans=tool.description.human if tool.description else "",
|
||||
),
|
||||
parameters=current_parameters,
|
||||
labels=labels,
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
name=tool.operation_id or "",
|
||||
label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels,
|
||||
|
@@ -6,8 +6,10 @@ from typing import Any, Optional
|
||||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from extensions.ext_database import db
|
||||
@@ -32,7 +34,7 @@ class WorkflowToolManageService:
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: Mapping[str, Any],
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> dict:
|
||||
@@ -97,7 +99,7 @@ class WorkflowToolManageService:
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[dict],
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: Optional[list[str]] = None,
|
||||
) -> dict:
|
||||
@@ -131,7 +133,7 @@ class WorkflowToolManageService:
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = (
|
||||
workflow_tool_provider: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -140,14 +142,14 @@ class WorkflowToolManageService:
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App = (
|
||||
app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
workflow: Optional[Workflow] = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
@@ -193,7 +195,7 @@ class WorkflowToolManageService:
|
||||
# skip deleted tools
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(tools)
|
||||
labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
|
||||
|
||||
result = []
|
||||
|
||||
@@ -202,10 +204,11 @@ class WorkflowToolManageService:
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
continue
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
|
||||
@@ -236,7 +239,7 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -245,13 +248,19 @@ class WorkflowToolManageService:
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
@@ -261,9 +270,9 @@ class WorkflowToolManageService:
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@@ -276,7 +285,7 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
@@ -285,12 +294,17 @@ class WorkflowToolManageService:
|
||||
if db_tool is None:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
workflow_app: Optional[App] = (
|
||||
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
@@ -301,14 +315,14 @@ class WorkflowToolManageService:
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
@@ -316,7 +330,7 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db_tool: Optional[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
@@ -326,9 +340,8 @@ class WorkflowToolManageService:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
)
|
||||
]
|
||||
return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]
|
||||
|
Reference in New Issue
Block a user