feat: add MCP support (#20716)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -3,11 +3,13 @@ import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@@ -73,12 +75,14 @@ class AgentNode(ToolNode):
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
||||
# get conversation id
|
||||
@@ -155,6 +159,7 @@ class AgentNode(ToolNode):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@@ -207,7 +212,7 @@ class AgentNode(ToolNode):
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
@@ -244,9 +249,9 @@ class AgentNode(ToolNode):
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
@@ -398,3 +403,16 @@ class AgentNode(ToolNode):
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
|
||||
|
@@ -73,6 +73,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
"2": ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
@@ -122,6 +123,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
}
|
||||
|
@@ -41,6 +41,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
@@ -54,3 +58,22 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
return typ
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return {
|
||||
key: tool_input
|
||||
for key, tool_input in value.items()
|
||||
if tool_input is not None and cls._has_valid_value(tool_input)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
|
@@ -67,8 +67,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@@ -95,7 +96,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
|
Reference in New Issue
Block a user