Revert "feat: improved MCP timeout" (#23602)

This commit is contained in:
crazywoola
2025-08-07 20:20:53 -07:00
committed by GitHub
parent 084dcd1a50
commit 1c60b7f070
13 changed files with 47 additions and 153 deletions

View File

@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
except Exception:
except Exception as exc:
logger.exception("Error sending message")
raise

View File

@@ -55,10 +55,14 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
@@ -70,7 +74,7 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: float
sse_read_timeout: timedelta
class StreamableHTTPTransport:
@@ -80,8 +84,8 @@ class StreamableHTTPTransport:
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None:
"""Initialize the StreamableHTTP transport.
@@ -93,10 +97,8 @@ class StreamableHTTPTransport:
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
@@ -184,7 +186,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source:
@@ -213,7 +215,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source:
@@ -400,8 +402,8 @@ class StreamableHTTPTransport:
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
) -> Generator[
tuple[
@@ -434,7 +436,7 @@ def streamablehttp_client(
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:

View File

@@ -23,18 +23,12 @@ class MCPClient:
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
@@ -49,7 +43,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self._exit_stack = ExitStack()
self.exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
@@ -96,26 +90,21 @@ class MCPClient:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
streams = self.exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session = self.exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@@ -131,6 +120,9 @@ class MCPClient:
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
@@ -150,7 +142,7 @@ class MCPClient:
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self._exit_stack.close()
self.exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@@ -2,6 +2,7 @@ import logging
import queue
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
@@ -169,6 +170,7 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
@@ -375,7 +377,7 @@ class BaseSession(
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception:
except Exception as e:
logging.exception("Error in message processing loop")
raise
@@ -387,12 +389,14 @@ class BaseSession(
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
@@ -401,9 +405,11 @@ class BaseSession(
Sends a progress notification for a request that is currently being
processed.
"""
pass
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@@ -1,4 +1,3 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
@@ -86,8 +85,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_stream,
write_stream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,