feat: improved MCP timeout (#23546)
This commit is contained in:
@@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
||||||
|
parser.add_argument(
|
||||||
|
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user = current_user
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
@@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
icon_background=args["icon_background"],
|
icon_background=args["icon_background"],
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
server_identifier=args["server_identifier"],
|
server_identifier=args["server_identifier"],
|
||||||
|
timeout=args["timeout"],
|
||||||
|
sse_read_timeout=args["sse_read_timeout"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
if "[__HIDDEN__]" in args["server_url"]:
|
if "[__HIDDEN__]" in args["server_url"]:
|
||||||
@@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
icon_type=args["icon_type"],
|
icon_type=args["icon_type"],
|
||||||
icon_background=args["icon_background"],
|
icon_background=args["icon_background"],
|
||||||
server_identifier=args["server_identifier"],
|
server_identifier=args["server_identifier"],
|
||||||
|
timeout=args.get("timeout"),
|
||||||
|
sse_read_timeout=args.get("sse_read_timeout"),
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
logger.debug("Client message sent successfully: %s", response.status_code)
|
logger.debug("Client message sent successfully: %s", response.status_code)
|
||||||
except Exception as exc:
|
except Exception:
|
||||||
logger.exception("Error sending message")
|
logger.exception("Error sending message")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
|
|||||||
class StreamableHTTPError(Exception):
|
class StreamableHTTPError(Exception):
|
||||||
"""Base exception for StreamableHTTP transport errors."""
|
"""Base exception for StreamableHTTP transport errors."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResumptionError(StreamableHTTPError):
|
class ResumptionError(StreamableHTTPError):
|
||||||
"""Raised when resumption request is invalid."""
|
"""Raised when resumption request is invalid."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext:
|
class RequestContext:
|
||||||
@@ -74,7 +70,7 @@ class RequestContext:
|
|||||||
session_message: SessionMessage
|
session_message: SessionMessage
|
||||||
metadata: ClientMessageMetadata | None
|
metadata: ClientMessageMetadata | None
|
||||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||||
sse_read_timeout: timedelta
|
sse_read_timeout: float
|
||||||
|
|
||||||
|
|
||||||
class StreamableHTTPTransport:
|
class StreamableHTTPTransport:
|
||||||
@@ -84,8 +80,8 @@ class StreamableHTTPTransport:
|
|||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: timedelta = timedelta(seconds=30),
|
timeout: float | timedelta = 30,
|
||||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
sse_read_timeout: float | timedelta = 60 * 5,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the StreamableHTTP transport.
|
"""Initialize the StreamableHTTP transport.
|
||||||
|
|
||||||
@@ -97,8 +93,10 @@ class StreamableHTTPTransport:
|
|||||||
"""
|
"""
|
||||||
self.url = url
|
self.url = url
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
self.timeout = timeout
|
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||||
self.sse_read_timeout = sse_read_timeout
|
self.sse_read_timeout = (
|
||||||
|
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||||
|
)
|
||||||
self.session_id: str | None = None
|
self.session_id: str | None = None
|
||||||
self.request_headers = {
|
self.request_headers = {
|
||||||
ACCEPT: f"{JSON}, {SSE}",
|
ACCEPT: f"{JSON}, {SSE}",
|
||||||
@@ -186,7 +184,7 @@ class StreamableHTTPTransport:
|
|||||||
with ssrf_proxy_sse_connect(
|
with ssrf_proxy_sse_connect(
|
||||||
self.url,
|
self.url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||||
client=client,
|
client=client,
|
||||||
method="GET",
|
method="GET",
|
||||||
) as event_source:
|
) as event_source:
|
||||||
@@ -215,7 +213,7 @@ class StreamableHTTPTransport:
|
|||||||
with ssrf_proxy_sse_connect(
|
with ssrf_proxy_sse_connect(
|
||||||
self.url,
|
self.url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||||
client=ctx.client,
|
client=ctx.client,
|
||||||
method="GET",
|
method="GET",
|
||||||
) as event_source:
|
) as event_source:
|
||||||
@@ -402,8 +400,8 @@ class StreamableHTTPTransport:
|
|||||||
def streamablehttp_client(
|
def streamablehttp_client(
|
||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
timeout: timedelta = timedelta(seconds=30),
|
timeout: float | timedelta = 30,
|
||||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
sse_read_timeout: float | timedelta = 60 * 5,
|
||||||
terminate_on_close: bool = True,
|
terminate_on_close: bool = True,
|
||||||
) -> Generator[
|
) -> Generator[
|
||||||
tuple[
|
tuple[
|
||||||
@@ -436,7 +434,7 @@ def streamablehttp_client(
|
|||||||
try:
|
try:
|
||||||
with create_ssrf_proxy_mcp_http_client(
|
with create_ssrf_proxy_mcp_http_client(
|
||||||
headers=transport.request_headers,
|
headers=transport.request_headers,
|
||||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||||
) as client:
|
) as client:
|
||||||
# Define callbacks that need access to thread pool
|
# Define callbacks that need access to thread pool
|
||||||
def start_get_stream() -> None:
|
def start_get_stream() -> None:
|
||||||
|
@@ -23,12 +23,18 @@ class MCPClient:
|
|||||||
authed: bool = True,
|
authed: bool = True,
|
||||||
authorization_code: Optional[str] = None,
|
authorization_code: Optional[str] = None,
|
||||||
for_list: bool = False,
|
for_list: bool = False,
|
||||||
|
headers: Optional[dict[str, str]] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
sse_read_timeout: Optional[float] = None,
|
||||||
):
|
):
|
||||||
# Initialize info
|
# Initialize info
|
||||||
self.provider_id = provider_id
|
self.provider_id = provider_id
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.client_type = "streamable"
|
self.client_type = "streamable"
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.sse_read_timeout = sse_read_timeout
|
||||||
|
|
||||||
# Authentication info
|
# Authentication info
|
||||||
self.authed = authed
|
self.authed = authed
|
||||||
@@ -43,7 +49,7 @@ class MCPClient:
|
|||||||
self._session: Optional[ClientSession] = None
|
self._session: Optional[ClientSession] = None
|
||||||
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
||||||
self._session_context: Optional[ClientSession] = None
|
self._session_context: Optional[ClientSession] = None
|
||||||
self.exit_stack = ExitStack()
|
self._exit_stack = ExitStack()
|
||||||
|
|
||||||
# Whether the client has been initialized
|
# Whether the client has been initialized
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
@@ -90,21 +96,26 @@ class MCPClient:
|
|||||||
headers = (
|
headers = (
|
||||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||||
if self.authed and self.token
|
if self.authed and self.token
|
||||||
else {}
|
else self.headers
|
||||||
|
)
|
||||||
|
self._streams_context = client_factory(
|
||||||
|
url=self.server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
)
|
)
|
||||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
|
||||||
if not self._streams_context:
|
if not self._streams_context:
|
||||||
raise MCPConnectionError("Failed to create connection context")
|
raise MCPConnectionError("Failed to create connection context")
|
||||||
|
|
||||||
# Use exit_stack to manage context managers properly
|
# Use exit_stack to manage context managers properly
|
||||||
if method_name == "mcp":
|
if method_name == "mcp":
|
||||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
|
||||||
streams = (read_stream, write_stream)
|
streams = (read_stream, write_stream)
|
||||||
else: # sse_client
|
else: # sse_client
|
||||||
streams = self.exit_stack.enter_context(self._streams_context)
|
streams = self._exit_stack.enter_context(self._streams_context)
|
||||||
|
|
||||||
self._session_context = ClientSession(*streams)
|
self._session_context = ClientSession(*streams)
|
||||||
self._session = self.exit_stack.enter_context(self._session_context)
|
self._session = self._exit_stack.enter_context(self._session_context)
|
||||||
session = cast(ClientSession, self._session)
|
session = cast(ClientSession, self._session)
|
||||||
session.initialize()
|
session.initialize()
|
||||||
return
|
return
|
||||||
@@ -120,9 +131,6 @@ class MCPClient:
|
|||||||
if first_try:
|
if first_try:
|
||||||
return self.connect_server(client_factory, method_name, first_try=False)
|
return self.connect_server(client_factory, method_name, first_try=False)
|
||||||
|
|
||||||
except MCPConnectionError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def list_tools(self) -> list[Tool]:
|
def list_tools(self) -> list[Tool]:
|
||||||
"""Connect to an MCP server running with SSE transport"""
|
"""Connect to an MCP server running with SSE transport"""
|
||||||
# List available tools to verify connection
|
# List available tools to verify connection
|
||||||
@@ -142,7 +150,7 @@ class MCPClient:
|
|||||||
"""Clean up resources"""
|
"""Clean up resources"""
|
||||||
try:
|
try:
|
||||||
# ExitStack will handle proper cleanup of all managed context managers
|
# ExitStack will handle proper cleanup of all managed context managers
|
||||||
self.exit_stack.close()
|
self._exit_stack.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("Error during cleanup")
|
logging.exception("Error during cleanup")
|
||||||
raise ValueError(f"Error during cleanup: {e}")
|
raise ValueError(f"Error during cleanup: {e}")
|
||||||
|
@@ -2,7 +2,6 @@ import logging
|
|||||||
import queue
|
import queue
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||||
from contextlib import ExitStack
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Generic, Self, TypeVar
|
from typing import Any, Generic, Self, TypeVar
|
||||||
@@ -170,7 +169,6 @@ class BaseSession(
|
|||||||
self._receive_notification_type = receive_notification_type
|
self._receive_notification_type = receive_notification_type
|
||||||
self._session_read_timeout_seconds = read_timeout_seconds
|
self._session_read_timeout_seconds = read_timeout_seconds
|
||||||
self._in_flight = {}
|
self._in_flight = {}
|
||||||
self._exit_stack = ExitStack()
|
|
||||||
# Initialize executor and future to None for proper cleanup checks
|
# Initialize executor and future to None for proper cleanup checks
|
||||||
self._executor: ThreadPoolExecutor | None = None
|
self._executor: ThreadPoolExecutor | None = None
|
||||||
self._receiver_future: Future | None = None
|
self._receiver_future: Future | None = None
|
||||||
@@ -377,7 +375,7 @@ class BaseSession(
|
|||||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logging.exception("Error in message processing loop")
|
logging.exception("Error in message processing loop")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -389,14 +387,12 @@ class BaseSession(
|
|||||||
If the request is responded to within this method, it will not be
|
If the request is responded to within this method, it will not be
|
||||||
forwarded on to the message stream.
|
forwarded on to the message stream.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||||
"""
|
"""
|
||||||
Can be overridden by subclasses to handle a notification without needing
|
Can be overridden by subclasses to handle a notification without needing
|
||||||
to listen on the message stream.
|
to listen on the message stream.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def send_progress_notification(
|
def send_progress_notification(
|
||||||
self, progress_token: str | int, progress: float, total: float | None = None
|
self, progress_token: str | int, progress: float, total: float | None = None
|
||||||
@@ -405,11 +401,9 @@ class BaseSession(
|
|||||||
Sends a progress notification for a request that is currently being
|
Sends a progress notification for a request that is currently being
|
||||||
processed.
|
processed.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _handle_incoming(
|
def _handle_incoming(
|
||||||
self,
|
self,
|
||||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||||
pass
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import queue
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
@@ -85,8 +86,8 @@ class ClientSession(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
read_stream,
|
read_stream: queue.Queue,
|
||||||
write_stream,
|
write_stream: queue.Queue,
|
||||||
read_timeout_seconds: timedelta | None = None,
|
read_timeout_seconds: timedelta | None = None,
|
||||||
sampling_callback: SamplingFnT | None = None,
|
sampling_callback: SamplingFnT | None = None,
|
||||||
list_roots_callback: ListRootsFnT | None = None,
|
list_roots_callback: ListRootsFnT | None = None,
|
||||||
|
@@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
|||||||
|
|
||||||
|
|
||||||
class ToolProviderController(ABC):
|
class ToolProviderController(ABC):
|
||||||
entity: ToolProviderEntity
|
|
||||||
|
|
||||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.mcp.types import Tool as RemoteMCPTool
|
from core.mcp.types import Tool as RemoteMCPTool
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
@@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
|
|||||||
|
|
||||||
|
|
||||||
class MCPToolProviderController(ToolProviderController):
|
class MCPToolProviderController(ToolProviderController):
|
||||||
provider_id: str
|
def __init__(
|
||||||
entity: ToolProviderEntityWithPlugin
|
self,
|
||||||
|
entity: ToolProviderEntityWithPlugin,
|
||||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
provider_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
server_url: str,
|
||||||
|
headers: Optional[dict[str, str]] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
sse_read_timeout: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__(entity)
|
super().__init__(entity)
|
||||||
self.entity = entity
|
self.entity: ToolProviderEntityWithPlugin = entity
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.provider_id = provider_id
|
self.provider_id = provider_id
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.sse_read_timeout = sse_read_timeout
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
@@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
provider_id=db_provider.server_identifier or "",
|
provider_id=db_provider.server_identifier or "",
|
||||||
tenant_id=db_provider.tenant_id or "",
|
tenant_id=db_provider.tenant_id or "",
|
||||||
server_url=db_provider.decrypted_server_url,
|
server_url=db_provider.decrypted_server_url,
|
||||||
|
headers={}, # TODO: get headers from db provider
|
||||||
|
timeout=db_provider.timeout,
|
||||||
|
sse_read_timeout=db_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
@@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
icon=self.entity.identity.icon,
|
icon=self.entity.identity.icon,
|
||||||
server_url=self.server_url,
|
server_url=self.server_url,
|
||||||
provider_id=self.provider_id,
|
provider_id=self.provider_id,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||||
@@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
icon=self.entity.identity.icon,
|
icon=self.entity.identity.icon,
|
||||||
server_url=self.server_url,
|
server_url=self.server_url,
|
||||||
provider_id=self.provider_id,
|
provider_id=self.provider_id,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
)
|
)
|
||||||
for tool_entity in self.entity.tools
|
for tool_entity in self.entity.tools
|
||||||
]
|
]
|
||||||
|
@@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||||||
|
|
||||||
class MCPTool(Tool):
|
class MCPTool(Tool):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
self,
|
||||||
|
entity: ToolEntity,
|
||||||
|
runtime: ToolRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
server_url: str,
|
||||||
|
provider_id: str,
|
||||||
|
headers: Optional[dict[str, str]] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
sse_read_timeout: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(entity, runtime)
|
super().__init__(entity, runtime)
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.icon = icon
|
self.icon = icon
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.provider_id = provider_id
|
self.provider_id = provider_id
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.sse_read_timeout = sse_read_timeout
|
||||||
|
|
||||||
def tool_provider_type(self) -> ToolProviderType:
|
def tool_provider_type(self) -> ToolProviderType:
|
||||||
return ToolProviderType.MCP
|
return ToolProviderType.MCP
|
||||||
@@ -35,7 +47,15 @@ class MCPTool(Tool):
|
|||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
with MCPClient(
|
||||||
|
self.server_url,
|
||||||
|
self.provider_id,
|
||||||
|
self.tenant_id,
|
||||||
|
authed=True,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
|
) as mcp_client:
|
||||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||||
except MCPAuthError as e:
|
except MCPAuthError as e:
|
||||||
@@ -72,6 +92,9 @@ class MCPTool(Tool):
|
|||||||
icon=self.icon,
|
icon=self.icon,
|
||||||
server_url=self.server_url,
|
server_url=self.server_url,
|
||||||
provider_id=self.provider_id,
|
provider_id=self.provider_id,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
@@ -789,9 +789,6 @@ class ToolManager:
|
|||||||
"""
|
"""
|
||||||
get api provider
|
get api provider
|
||||||
"""
|
"""
|
||||||
"""
|
|
||||||
get tool provider
|
|
||||||
"""
|
|
||||||
provider_name = provider
|
provider_name = provider
|
||||||
provider_obj: ApiToolProvider | None = (
|
provider_obj: ApiToolProvider | None = (
|
||||||
db.session.query(ApiToolProvider)
|
db.session.query(ApiToolProvider)
|
||||||
|
@@ -0,0 +1,47 @@
|
|||||||
|
"""add timeout for tool_mcp_providers
|
||||||
|
|
||||||
|
Revision ID: fa8b0fa6f407
|
||||||
|
Revises: 532b3f888abf
|
||||||
|
Create Date: 2025-08-07 11:15:31.517985
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'fa8b0fa6f407'
|
||||||
|
down_revision = '532b3f888abf'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
|
||||||
|
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
|
||||||
|
|
||||||
|
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx'))
|
||||||
|
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_run_created_at_idx'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False)
|
||||||
|
|
||||||
|
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||||
|
batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False)
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('sse_read_timeout')
|
||||||
|
batch_op.drop_column('timeout')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@@ -278,6 +278,8 @@ class MCPToolProvider(Base):
|
|||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||||
)
|
)
|
||||||
|
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||||
|
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||||
|
|
||||||
def load_user(self) -> Account | None:
|
def load_user(self) -> Account | None:
|
||||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||||
|
@@ -59,6 +59,8 @@ class MCPToolManageService:
|
|||||||
icon_type: str,
|
icon_type: str,
|
||||||
icon_background: str,
|
icon_background: str,
|
||||||
server_identifier: str,
|
server_identifier: str,
|
||||||
|
timeout: float,
|
||||||
|
sse_read_timeout: float,
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
existing_provider = (
|
existing_provider = (
|
||||||
@@ -91,6 +93,8 @@ class MCPToolManageService:
|
|||||||
tools="[]",
|
tools="[]",
|
||||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||||
server_identifier=server_identifier,
|
server_identifier=server_identifier,
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
)
|
)
|
||||||
db.session.add(mcp_tool)
|
db.session.add(mcp_tool)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@@ -166,6 +170,8 @@ class MCPToolManageService:
|
|||||||
icon_type: str,
|
icon_type: str,
|
||||||
icon_background: str,
|
icon_background: str,
|
||||||
server_identifier: str,
|
server_identifier: str,
|
||||||
|
timeout: float | None = None,
|
||||||
|
sse_read_timeout: float | None = None,
|
||||||
):
|
):
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
|
|
||||||
@@ -197,6 +203,10 @@ class MCPToolManageService:
|
|||||||
mcp_provider.tools = reconnect_result["tools"]
|
mcp_provider.tools = reconnect_result["tools"]
|
||||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||||
|
|
||||||
|
if timeout is not None:
|
||||||
|
mcp_provider.timeout = timeout
|
||||||
|
if sse_read_timeout is not None:
|
||||||
|
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
|
Reference in New Issue
Block a user