From 1c60b7f0704c5b718c37bd71ca05b520f9a1d24b Mon Sep 17 00:00:00 2001 From: crazywoola <100913391+crazywoola@users.noreply.github.com> Date: Thu, 7 Aug 2025 20:20:53 -0700 Subject: [PATCH] Revert "feat: improved MCP timeout" (#23602) --- .../console/workspace/tool_providers.py | 10 ---- api/core/mcp/client/sse_client.py | 2 +- api/core/mcp/client/streamable_client.py | 26 +++++----- api/core/mcp/mcp_client.py | 28 ++++------- api/core/mcp/session/base_session.py | 8 +++- api/core/mcp/session/client_session.py | 5 +- api/core/tools/__base/tool_provider.py | 2 + api/core/tools/mcp_tool/provider.py | 30 +++--------- api/core/tools/mcp_tool/tool.py | 27 +---------- api/core/tools/tool_manager.py | 3 ++ ...f407_add_timeout_for_tool_mcp_providers.py | 47 ------------------- api/models/tools.py | 2 - .../tools/mcp_tools_manage_service.py | 10 ---- 13 files changed, 47 insertions(+), 153 deletions(-) delete mode 100644 api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 8c8b73b45..c4d1ef70d 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -862,10 +862,6 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) - parser.add_argument( - "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 - ) args = parser.parse_args() user = current_user if not is_valid_url(args["server_url"]): @@ -880,8 +876,6 @@ class ToolProviderMCPApi(Resource): icon_background=args["icon_background"], user_id=user.id, server_identifier=args["server_identifier"], - timeout=args["timeout"], - sse_read_timeout=args["sse_read_timeout"], ) ) @@ -897,8 +891,6 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") - parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") args = parser.parse_args() if not is_valid_url(args["server_url"]): if "[__HIDDEN__]" in args["server_url"]: @@ -914,8 +906,6 @@ class ToolProviderMCPApi(Resource): icon_type=args["icon_type"], icon_background=args["icon_background"], server_identifier=args["server_identifier"], - timeout=args.get("timeout"), - sse_read_timeout=args.get("sse_read_timeout"), ) return {"result": "success"} diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 2d3a3f534..4226e77f7 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message: ) response.raise_for_status() logger.debug("Client message sent successfully: %s", response.status_code) - except Exception: + except Exception as exc: logger.exception("Error sending message") raise diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 14e346c2f..ca414ebb9 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -55,10 +55,14 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" + pass + class ResumptionError(StreamableHTTPError): """Raised when resumption request is invalid.""" + pass + @dataclass class RequestContext: @@ -70,7 +74,7 @@ class RequestContext: session_message: SessionMessage metadata: ClientMessageMetadata | None server_to_client_queue: ServerToClientQueue # Renamed for clarity - sse_read_timeout: float + sse_read_timeout: timedelta class StreamableHTTPTransport: @@ -80,8 +84,8 @@ class StreamableHTTPTransport: self, url: str, headers: dict[str, Any] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), ) -> None: """Initialize the StreamableHTTP transport. @@ -93,10 +97,8 @@ class StreamableHTTPTransport: """ self.url = url self.headers = headers or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - self.sse_read_timeout = ( - sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout - ) + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout self.session_id: str | None = None self.request_headers = { ACCEPT: f"{JSON}, {SSE}", @@ -184,7 +186,7 @@ class StreamableHTTPTransport: with ssrf_proxy_sse_connect( self.url, headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), + timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), client=client, method="GET", ) as event_source: @@ -213,7 +215,7 @@ class StreamableHTTPTransport: with ssrf_proxy_sse_connect( self.url, headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), + timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), client=ctx.client, method="GET", ) as event_source: @@ -400,8 +402,8 @@ class StreamableHTTPTransport: def streamablehttp_client( url: str, headers: dict[str, Any] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, ) -> Generator[ tuple[ @@ -434,7 +436,7 @@ def streamablehttp_client( try: with create_ssrf_proxy_mcp_http_client( headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), + timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), ) as client: # Define callbacks that need access to thread pool def start_get_stream() -> None: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 7d90d5195..875d13de0 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -23,18 +23,12 @@ class MCPClient: authed: bool = True, authorization_code: Optional[str] = None, for_list: bool = False, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, ): # Initialize info self.provider_id = provider_id self.tenant_id = tenant_id self.client_type = "streamable" self.server_url = server_url - self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout # Authentication info self.authed = authed @@ -49,7 +43,7 @@ class MCPClient: self._session: Optional[ClientSession] = None self._streams_context: Optional[AbstractContextManager[Any]] = None self._session_context: Optional[ClientSession] = None - self._exit_stack = ExitStack() + self.exit_stack = ExitStack() # Whether the client has been initialized self._initialized = False @@ -96,26 +90,21 @@ class MCPClient: headers = ( {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} if self.authed and self.token - else self.headers - ) - self._streams_context = client_factory( - url=self.server_url, - headers=headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, + else {} ) + self._streams_context = client_factory(url=self.server_url, headers=headers) if not self._streams_context: raise MCPConnectionError("Failed to create connection context") # Use exit_stack to manage context managers properly if method_name == "mcp": - read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) + read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) streams = (read_stream, write_stream) else: # sse_client - streams = self._exit_stack.enter_context(self._streams_context) + streams = self.exit_stack.enter_context(self._streams_context) self._session_context = ClientSession(*streams) - self._session = self._exit_stack.enter_context(self._session_context) + self._session = self.exit_stack.enter_context(self._session_context) session = cast(ClientSession, self._session) session.initialize() return @@ -131,6 +120,9 @@ class MCPClient: if first_try: return self.connect_server(client_factory, method_name, first_try=False) + except MCPConnectionError: + raise + def list_tools(self) -> list[Tool]: """Connect to an MCP server running with SSE transport""" # List available tools to verify connection @@ -150,7 +142,7 @@ class MCPClient: """Clean up resources""" try: # ExitStack will handle proper cleanup of all managed context managers - self._exit_stack.close() + self.exit_stack.close() except Exception as e: logging.exception("Error during cleanup") raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 3f98aa94a..3b6c9a742 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -2,6 +2,7 @@ import logging import queue from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError +from contextlib import ExitStack from datetime import timedelta from types import TracebackType from typing import Any, Generic, Self, TypeVar @@ -169,6 +170,7 @@ class BaseSession( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._exit_stack = ExitStack() # Initialize executor and future to None for proper cleanup checks self._executor: ThreadPoolExecutor | None = None self._receiver_future: Future | None = None @@ -375,7 +377,7 @@ class BaseSession( self._handle_incoming(RuntimeError(f"Server Error: {message}")) except queue.Empty: continue - except Exception: + except Exception as e: logging.exception("Error in message processing loop") raise @@ -387,12 +389,14 @@ class BaseSession( If the request is responded to within this method, it will not be forwarded on to the message stream. """ + pass def _received_notification(self, notification: ReceiveNotificationT) -> None: """ Can be overridden by subclasses to handle a notification without needing to listen on the message stream. """ + pass def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None @@ -401,9 +405,11 @@ class BaseSession( Sends a progress notification for a request that is currently being processed. """ + pass def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" + pass diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 1bccf1d03..ed2ad508a 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -1,4 +1,3 @@ -import queue from datetime import timedelta from typing import Any, Protocol @@ -86,8 +85,8 @@ class ClientSession( ): def __init__( self, - read_stream: queue.Queue, - write_stream: queue.Queue, + read_stream, + write_stream, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index d1d7976cc..d096fc7df 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -12,6 +12,8 @@ from core.tools.errors import ToolProviderCredentialValidationError class ToolProviderController(ABC): + entity: ToolProviderEntity + def __init__(self, entity: ToolProviderEntity) -> None: self.entity = entity diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 24ee981a1..93f003eff 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController @@ -19,24 +19,15 @@ from services.tools.tools_transform_service import ToolTransformService class MCPToolProviderController(ToolProviderController): - def __init__( - self, - entity: ToolProviderEntityWithPlugin, - provider_id: str, - tenant_id: str, - server_url: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, - ) -> None: + provider_id: str + entity: ToolProviderEntityWithPlugin + + def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: super().__init__(entity) - self.entity: ToolProviderEntityWithPlugin = entity + self.entity = entity self.tenant_id = tenant_id self.provider_id = provider_id self.server_url = server_url - self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout @property def provider_type(self) -> ToolProviderType: @@ -94,9 +85,6 @@ class MCPToolProviderController(ToolProviderController): provider_id=db_provider.server_identifier or "", tenant_id=db_provider.tenant_id or "", server_url=db_provider.decrypted_server_url, - headers={}, # TODO: get headers from db provider - timeout=db_provider.timeout, - sse_read_timeout=db_provider.sse_read_timeout, ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: @@ -123,9 +111,6 @@ class MCPToolProviderController(ToolProviderController): icon=self.entity.identity.icon, server_url=self.server_url, provider_id=self.provider_id, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, ) def get_tools(self) -> list[MCPTool]: # type: ignore @@ -140,9 +125,6 @@ class MCPToolProviderController(ToolProviderController): icon=self.entity.identity.icon, server_url=self.server_url, provider_id=self.provider_id, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, ) for tool_entity in self.entity.tools ] diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 26789b23c..8ebbb6b0f 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -13,25 +13,13 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class MCPTool(Tool): def __init__( - self, - entity: ToolEntity, - runtime: ToolRuntime, - tenant_id: str, - icon: str, - server_url: str, - provider_id: str, - headers: Optional[dict[str, str]] = None, - timeout: Optional[float] = None, - sse_read_timeout: Optional[float] = None, + self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon self.server_url = server_url self.provider_id = provider_id - self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.MCP @@ -47,15 +35,7 @@ class MCPTool(Tool): from core.tools.errors import ToolInvokeError try: - with MCPClient( - self.server_url, - self.provider_id, - self.tenant_id, - authed=True, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - ) as mcp_client: + with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: tool_parameters = self._handle_none_parameter(tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) except MCPAuthError as e: @@ -92,9 +72,6 @@ class MCPTool(Tool): icon=self.icon, server_url=self.server_url, provider_id=self.provider_id, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, ) def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7472f4f60..2737bcfb1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -789,6 +789,9 @@ class ToolManager: """ get api provider """ + """ + get tool provider + """ provider_name = provider provider_obj: ApiToolProvider | None = ( db.session.query(ApiToolProvider) diff --git a/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py b/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py deleted file mode 100644 index eabead232..000000000 --- a/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py +++ /dev/null @@ -1,47 +0,0 @@ -"""add timeout for tool_mcp_providers - -Revision ID: fa8b0fa6f407 -Revises: 532b3f888abf -Create Date: 2025-08-07 11:15:31.517985 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'fa8b0fa6f407' -down_revision = '532b3f888abf' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False)) - batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False)) - - with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: - batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx')) - - with op.batch_alter_table('workflow_runs', schema=None) as batch_op: - batch_op.drop_index(batch_op.f('workflow_run_created_at_idx')) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('workflow_runs', schema=None) as batch_op: - batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False) - - with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: - batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False) - - with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: - batch_op.drop_column('sse_read_timeout') - batch_op.drop_column('timeout') - - # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index e0c9fa6ff..408c1371c 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -278,8 +278,6 @@ class MCPToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30")) - sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300")) def load_user(self) -> Account | None: return db.session.query(Account).where(Account.id == self.user_id).first() diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index f45c93176..23be449a5 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -59,8 +59,6 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float, - sse_read_timeout: float, ) -> ToolProviderApiEntity: server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() existing_provider = ( @@ -93,8 +91,6 @@ class MCPToolManageService: tools="[]", icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, server_identifier=server_identifier, - timeout=timeout, - sse_read_timeout=sse_read_timeout, ) db.session.add(mcp_tool) db.session.commit() @@ -170,8 +166,6 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, - timeout: float | None = None, - sse_read_timeout: float | None = None, ): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) @@ -203,10 +197,6 @@ class MCPToolManageService: mcp_provider.tools = reconnect_result["tools"] mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] - if timeout is not None: - mcp_provider.timeout = timeout - if sse_read_timeout is not None: - mcp_provider.sse_read_timeout = sse_read_timeout db.session.commit() except IntegrityError as e: db.session.rollback()