feat: add MCP support (#20716)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import Literal, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -18,7 +19,7 @@ class ToolApiEntity(BaseModel):
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
@@ -37,6 +38,10 @@ class ToolProviderApiEntity(BaseModel):
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
# MCP
|
||||
server_url: Optional[str] = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@@ -52,8 +57,13 @@ class ToolProviderApiEntity(BaseModel):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
if parameter.get("input_schema") is None:
|
||||
parameter.pop("input_schema", None)
|
||||
# -------------
|
||||
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
if self.type == ToolProviderType.MCP.value:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
@@ -69,4 +79,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
@@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
MCPServerParameterType,
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
@@ -49,6 +50,7 @@ class ToolProviderType(enum.StrEnum):
|
||||
API = "api"
|
||||
APP = "app"
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = "mcp"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
@@ -242,6 +244,10 @@ class ToolParameter(PluginParameter):
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY.value
|
||||
OBJECT = MCPServerParameterType.OBJECT.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
@@ -260,6 +266,8 @@ class ToolParameter(PluginParameter):
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
# MCP object and array type parameters use this field to store the schema
|
||||
input_schema: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
|
130
api/core/tools/mcp_tool/provider.py
Normal file
130
api/core/tools/mcp_tool/provider.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=None,
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
92
api/core/tools/mcp_tool/tool.py
Normal file
92
api/core/tools/mcp_tool/tool.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
for item in content_json:
|
||||
yield self.create_json_message(item)
|
||||
else:
|
||||
yield self.create_text_message(content.text)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
in mcp tool invoke, if the parameter is empty, it will be set to None
|
||||
"""
|
||||
return {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
@@ -4,7 +4,7 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@@ -13,9 +13,13 @@ from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
@@ -49,7 +53,7 @@ from core.tools.utils.configuration import (
|
||||
)
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +160,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@@ -292,6 +296,8 @@ class ToolManager:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
@@ -302,6 +308,7 @@ class ToolManager:
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@@ -316,24 +323,9 @@ class ToolManager:
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
# check file types
|
||||
if (
|
||||
parameter.type
|
||||
in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
and parameter.required
|
||||
):
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
runtime_parameters = cls._convert_tool_parameters_type(
|
||||
parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
|
||||
)
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
@@ -357,10 +349,12 @@ class ToolManager:
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
|
||||
tool_runtime = cls.get_tool_runtime(
|
||||
provider_type=workflow_tool.provider_type,
|
||||
provider_id=workflow_tool.provider_id,
|
||||
@@ -369,15 +363,11 @@ class ToolManager:
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
|
||||
parameters = tool_runtime.get_merged_runtime_parameters()
|
||||
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
runtime_parameters = cls._convert_tool_parameters_type(
|
||||
parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
|
||||
)
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
@@ -569,7 +559,7 @@ class ToolManager:
|
||||
|
||||
filters = []
|
||||
if not typ:
|
||||
filters.extend(["builtin", "api", "workflow"])
|
||||
filters.extend(["builtin", "api", "workflow", "mcp"])
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
@@ -663,6 +653,10 @@ class ToolManager:
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
@@ -698,6 +692,32 @@ class ToolManager:
|
||||
|
||||
return controller, provider.credentials
|
||||
|
||||
@classmethod
|
||||
def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider_id: the id of the provider
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
controller = MCPToolProviderController._from_db(provider)
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
@@ -826,6 +846,22 @@ class ToolManager:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mcp_provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
return mcp_provider.provider_icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def get_tool_icon(
|
||||
cls,
|
||||
@@ -863,8 +899,61 @@ class ToolManager:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: Optional[VariablePool],
|
||||
tool_configurations: dict[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert tool parameters type
|
||||
"""
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.nodes.tool.exc import ToolParameterError
|
||||
|
||||
runtime_parameters = {}
|
||||
for parameter in parameters:
|
||||
if (
|
||||
parameter.type
|
||||
in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}
|
||||
and parameter.required
|
||||
and typ == "agent"
|
||||
):
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
if variable_pool:
|
||||
config = tool_configurations.get(parameter.name, {})
|
||||
if not (config and isinstance(config, dict) and config.get("value") is not None):
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
runtime_parameters[parameter.name] = parameter_value
|
||||
|
||||
else:
|
||||
value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
return runtime_parameters
|
||||
|
||||
|
||||
ToolManager.load_hardcoded_providers_cache()
|
||||
|
@@ -72,21 +72,21 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
if use_cache:
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
@@ -105,7 +105,8 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
if use_cache:
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
|
@@ -8,7 +8,12 @@ from flask_login import current_user
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
Reference in New Issue
Block a user