feat: improved MCP timeout (#23546)
This commit is contained in:
@@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
||||
parser.add_argument(
|
||||
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
||||
)
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
if not is_valid_url(args["server_url"]):
|
||||
@@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
|
||||
icon_background=args["icon_background"],
|
||||
user_id=user.id,
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args["timeout"],
|
||||
sse_read_timeout=args["sse_read_timeout"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
@@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug("Client message sent successfully: %s", response.status_code)
|
||||
except Exception as exc:
|
||||
except Exception:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
@@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
@@ -74,7 +70,7 @@ class RequestContext:
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
sse_read_timeout: float
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
@@ -84,8 +80,8 @@ class StreamableHTTPTransport:
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
@@ -97,8 +93,10 @@ class StreamableHTTPTransport:
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||
self.sse_read_timeout = (
|
||||
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||
)
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
@@ -186,7 +184,7 @@ class StreamableHTTPTransport:
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
@@ -215,7 +213,7 @@ class StreamableHTTPTransport:
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
@@ -402,8 +400,8 @@ class StreamableHTTPTransport:
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
@@ -436,7 +434,7 @@ def streamablehttp_client(
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
|
@@ -23,12 +23,18 @@ class MCPClient:
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
@@ -43,7 +49,7 @@ class MCPClient:
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
||||
self._session_context: Optional[ClientSession] = None
|
||||
self.exit_stack = ExitStack()
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
@@ -90,21 +96,26 @@ class MCPClient:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
else self.headers
|
||||
)
|
||||
self._streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if not self._streams_context:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self.exit_stack.enter_context(self._streams_context)
|
||||
streams = self._exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self.exit_stack.enter_context(self._session_context)
|
||||
self._session = self._exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
@@ -120,9 +131,6 @@ class MCPClient:
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
raise
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
@@ -142,7 +150,7 @@ class MCPClient:
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self.exit_stack.close()
|
||||
self._exit_stack.close()
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
|
@@ -2,7 +2,6 @@ import logging
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from contextlib import ExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
@@ -170,7 +169,6 @@ class BaseSession(
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
# Initialize executor and future to None for proper cleanup checks
|
||||
self._executor: ThreadPoolExecutor | None = None
|
||||
self._receiver_future: Future | None = None
|
||||
@@ -377,7 +375,7 @@ class BaseSession(
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
@@ -389,14 +387,12 @@ class BaseSession(
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
@@ -405,11 +401,9 @@ class BaseSession(
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
@@ -85,8 +86,8 @@ class ClientSession(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream,
|
||||
write_stream,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
|
@@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolProviderEntityWithPlugin,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
server_url: str,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.entity = entity
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
headers={}, # TODO: get headers from db provider
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
@@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
@@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
|
@@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
||||
|
||||
class MCPTool(Tool):
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
self,
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
@@ -35,7 +47,15 @@ class MCPTool(Tool):
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
self.provider_id,
|
||||
self.tenant_id,
|
||||
authed=True,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
@@ -72,6 +92,9 @@ class MCPTool(Tool):
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
|
@@ -789,9 +789,6 @@ class ToolManager:
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
"""
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
|
@@ -0,0 +1,47 @@
|
||||
"""add timeout for tool_mcp_providers
|
||||
|
||||
Revision ID: fa8b0fa6f407
|
||||
Revises: 532b3f888abf
|
||||
Create Date: 2025-08-07 11:15:31.517985
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fa8b0fa6f407'
|
||||
down_revision = '532b3f888abf'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
|
||||
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
|
||||
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx'))
|
||||
|
||||
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('workflow_run_created_at_idx'))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False)
|
||||
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False)
|
||||
|
||||
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('sse_read_timeout')
|
||||
batch_op.drop_column('timeout')
|
||||
|
||||
# ### end Alembic commands ###
|
@@ -278,6 +278,8 @@ class MCPToolProvider(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
@@ -59,6 +59,8 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
@@ -91,6 +93,8 @@ class MCPToolManageService:
|
||||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
@@ -166,6 +170,8 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
@@ -197,6 +203,10 @@ class MCPToolManageService:
|
||||
mcp_provider.tools = reconnect_result["tools"]
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
|
||||
if timeout is not None:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
|
Reference in New Issue
Block a user