Feat/improved mcp timeout configs (#23605)

Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
Will
2025-08-12 13:14:00 +08:00
committed by GitHub
parent d3eff9b1a3
commit 1ffe190557
18 changed files with 180 additions and 48 deletions

View File

@@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
entity: ToolProviderEntity
def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
@@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
provider_id: str
entity: ToolProviderEntityWithPlugin
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity)
self.entity = entity
self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
@property
def provider_type(self) -> ToolProviderType:
@@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
@@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
@@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
for tool_entity in self.entity.tools
]

View File

@@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class MCPTool(Tool):
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.server_url = server_url
self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@@ -35,7 +47,15 @@ class MCPTool(Tool):
from core.tools.errors import ToolInvokeError
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
@@ -72,6 +92,9 @@ class MCPTool(Tool):
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

View File

@@ -789,9 +789,6 @@ class ToolManager:
"""
get api provider
"""
"""
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)