feat: improved MCP timeout (#23546)

This commit is contained in:
Will
2025-08-08 09:08:14 +08:00
committed by GitHub
parent c8c591d73c
commit 4b0480c8b3
13 changed files with 153 additions and 47 deletions

View File

@@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
@@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
icon_background=args["icon_background"], icon_background=args["icon_background"],
user_id=user.id, user_id=user.id,
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
) )
) )
@@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]: if "[__HIDDEN__]" in args["server_url"]:
@@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"], icon_type=args["icon_type"],
icon_background=args["icon_background"], icon_background=args["icon_background"],
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
) )
return {"result": "success"} return {"result": "success"}

View File

@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
) )
response.raise_for_status() response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code) logger.debug("Client message sent successfully: %s", response.status_code)
except Exception as exc: except Exception:
logger.exception("Error sending message") logger.exception("Error sending message")
raise raise

View File

@@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception): class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors.""" """Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError): class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid.""" """Raised when resumption request is invalid."""
pass
@dataclass @dataclass
class RequestContext: class RequestContext:
@@ -74,7 +70,7 @@ class RequestContext:
session_message: SessionMessage session_message: SessionMessage
metadata: ClientMessageMetadata | None metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: timedelta sse_read_timeout: float
class StreamableHTTPTransport: class StreamableHTTPTransport:
@@ -84,8 +80,8 @@ class StreamableHTTPTransport:
self, self,
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30), timeout: float | timedelta = 30,
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), sse_read_timeout: float | timedelta = 60 * 5,
) -> None: ) -> None:
"""Initialize the StreamableHTTP transport. """Initialize the StreamableHTTP transport.
@@ -97,8 +93,10 @@ class StreamableHTTPTransport:
""" """
self.url = url self.url = url
self.headers = headers or {} self.headers = headers or {}
self.timeout = timeout self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = sse_read_timeout self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.session_id: str | None = None self.session_id: str | None = None
self.request_headers = { self.request_headers = {
ACCEPT: f"{JSON}, {SSE}", ACCEPT: f"{JSON}, {SSE}",
@@ -186,7 +184,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=client, client=client,
method="GET", method="GET",
) as event_source: ) as event_source:
@@ -215,7 +213,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=ctx.client, client=ctx.client,
method="GET", method="GET",
) as event_source: ) as event_source:
@@ -402,8 +400,8 @@ class StreamableHTTPTransport:
def streamablehttp_client( def streamablehttp_client(
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30), timeout: float | timedelta = 30,
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True, terminate_on_close: bool = True,
) -> Generator[ ) -> Generator[
tuple[ tuple[
@@ -436,7 +434,7 @@ def streamablehttp_client(
try: try:
with create_ssrf_proxy_mcp_http_client( with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client: ) as client:
# Define callbacks that need access to thread pool # Define callbacks that need access to thread pool
def start_get_stream() -> None: def start_get_stream() -> None:

View File

@@ -23,12 +23,18 @@ class MCPClient:
authed: bool = True, authed: bool = True,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
for_list: bool = False, for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
): ):
# Initialize info # Initialize info
self.provider_id = provider_id self.provider_id = provider_id
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.client_type = "streamable" self.client_type = "streamable"
self.server_url = server_url self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info # Authentication info
self.authed = authed self.authed = authed
@@ -43,7 +49,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack() self._exit_stack = ExitStack()
# Whether the client has been initialized # Whether the client has been initialized
self._initialized = False self._initialized = False
@@ -90,21 +96,26 @@ class MCPClient:
headers = ( headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token if self.authed and self.token
else {} else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context: if not self._streams_context:
raise MCPConnectionError("Failed to create connection context") raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly # Use exit_stack to manage context managers properly
if method_name == "mcp": if method_name == "mcp":
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream) streams = (read_stream, write_stream)
else: # sse_client else: # sse_client
streams = self.exit_stack.enter_context(self._streams_context) streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams) self._session_context = ClientSession(*streams)
self._session = self.exit_stack.enter_context(self._session_context) self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session) session = cast(ClientSession, self._session)
session.initialize() session.initialize()
return return
@@ -120,9 +131,6 @@ class MCPClient:
if first_try: if first_try:
return self.connect_server(client_factory, method_name, first_try=False) return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]: def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport""" """Connect to an MCP server running with SSE transport"""
# List available tools to verify connection # List available tools to verify connection
@@ -142,7 +150,7 @@ class MCPClient:
"""Clean up resources""" """Clean up resources"""
try: try:
# ExitStack will handle proper cleanup of all managed context managers # ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close() self._exit_stack.close()
except Exception as e: except Exception as e:
logging.exception("Error during cleanup") logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}") raise ValueError(f"Error during cleanup: {e}")

View File

@@ -2,7 +2,6 @@ import logging
import queue import queue
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
from typing import Any, Generic, Self, TypeVar from typing import Any, Generic, Self, TypeVar
@@ -170,7 +169,6 @@ class BaseSession(
self._receive_notification_type = receive_notification_type self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks # Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None self._receiver_future: Future | None = None
@@ -377,7 +375,7 @@ class BaseSession(
self._handle_incoming(RuntimeError(f"Server Error: {message}")) self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty: except queue.Empty:
continue continue
except Exception as e: except Exception:
logging.exception("Error in message processing loop") logging.exception("Error in message processing loop")
raise raise
@@ -389,14 +387,12 @@ class BaseSession(
If the request is responded to within this method, it will not be If the request is responded to within this method, it will not be
forwarded on to the message stream. forwarded on to the message stream.
""" """
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None: def _received_notification(self, notification: ReceiveNotificationT) -> None:
""" """
Can be overridden by subclasses to handle a notification without needing Can be overridden by subclasses to handle a notification without needing
to listen on the message stream. to listen on the message stream.
""" """
pass
def send_progress_notification( def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None self, progress_token: str | int, progress: float, total: float | None = None
@@ -405,11 +401,9 @@ class BaseSession(
Sends a progress notification for a request that is currently being Sends a progress notification for a request that is currently being
processed. processed.
""" """
pass
def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None: ) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses.""" """A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@@ -1,3 +1,4 @@
import queue
from datetime import timedelta from datetime import timedelta
from typing import Any, Protocol from typing import Any, Protocol
@@ -85,8 +86,8 @@ class ClientSession(
): ):
def __init__( def __init__(
self, self,
read_stream, read_stream: queue.Queue,
write_stream, write_stream: queue.Queue,
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None, sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None, list_roots_callback: ListRootsFnT | None = None,

View File

@@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC): class ToolProviderController(ABC):
entity: ToolProviderEntity
def __init__(self, entity: ToolProviderEntity) -> None: def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity self.entity = entity

View File

@@ -1,5 +1,5 @@
import json import json
from typing import Any from typing import Any, Optional
from core.mcp.types import Tool as RemoteMCPTool from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
@@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController): class MCPToolProviderController(ToolProviderController):
provider_id: str def __init__(
entity: ToolProviderEntityWithPlugin self,
entity: ToolProviderEntityWithPlugin,
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity) super().__init__(entity)
self.entity = entity self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.provider_id = provider_id self.provider_id = provider_id
self.server_url = server_url self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
@@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
provider_id=db_provider.server_identifier or "", provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "", tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url, server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
) )
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
@@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon, icon=self.entity.identity.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )
def get_tools(self) -> list[MCPTool]: # type: ignore def get_tools(self) -> list[MCPTool]: # type: ignore
@@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon, icon=self.entity.identity.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )
for tool_entity in self.entity.tools for tool_entity in self.entity.tools
] ]

View File

@@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class MCPTool(Tool): class MCPTool(Tool):
def __init__( def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None: ) -> None:
super().__init__(entity, runtime) super().__init__(entity, runtime)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.icon = icon self.icon = icon
self.server_url = server_url self.server_url = server_url
self.provider_id = provider_id self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP return ToolProviderType.MCP
@@ -35,7 +47,15 @@ class MCPTool(Tool):
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError
try: try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters) tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e: except MCPAuthError as e:
@@ -72,6 +92,9 @@ class MCPTool(Tool):
icon=self.icon, icon=self.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

View File

@@ -789,9 +789,6 @@ class ToolManager:
""" """
get api provider get api provider
""" """
"""
get tool provider
"""
provider_name = provider provider_name = provider
provider_obj: ApiToolProvider | None = ( provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)

View File

@@ -0,0 +1,47 @@
"""add timeout for tool_mcp_providers
Revision ID: fa8b0fa6f407
Revises: 532b3f888abf
Create Date: 2025-08-07 11:15:31.517985
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fa8b0fa6f407'
down_revision = '532b3f888abf'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx'))
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_run_created_at_idx'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False)
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False)
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.drop_column('sse_read_timeout')
batch_op.drop_column('timeout')
# ### end Alembic commands ###

View File

@@ -278,6 +278,8 @@ class MCPToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
) )
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
def load_user(self) -> Account | None: def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first() return db.session.query(Account).where(Account.id == self.user_id).first()

View File

@@ -59,6 +59,8 @@ class MCPToolManageService:
icon_type: str, icon_type: str,
icon_background: str, icon_background: str,
server_identifier: str, server_identifier: str,
timeout: float,
sse_read_timeout: float,
) -> ToolProviderApiEntity: ) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = ( existing_provider = (
@@ -91,6 +93,8 @@ class MCPToolManageService:
tools="[]", tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
server_identifier=server_identifier, server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) )
db.session.add(mcp_tool) db.session.add(mcp_tool)
db.session.commit() db.session.commit()
@@ -166,6 +170,8 @@ class MCPToolManageService:
icon_type: str, icon_type: str,
icon_background: str, icon_background: str,
server_identifier: str, server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
): ):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@@ -197,6 +203,10 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"] mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
db.session.commit() db.session.commit()
except IntegrityError as e: except IntegrityError as e:
db.session.rollback() db.session.rollback()