feat: improved MCP timeout (#23546)

This commit is contained in:
Will
2025-08-08 09:08:14 +08:00
committed by GitHub
parent c8c591d73c
commit 4b0480c8b3
13 changed files with 153 additions and 47 deletions

View File

@@ -23,12 +23,18 @@ class MCPClient:
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
@@ -43,7 +49,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack()
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
@@ -90,21 +96,26 @@ class MCPClient:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self.exit_stack.enter_context(self._streams_context)
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self.exit_stack.enter_context(self._session_context)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@@ -120,9 +131,6 @@ class MCPClient:
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
@@ -142,7 +150,7 @@ class MCPClient:
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
self._exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")