From 658157e9a1fded47732688083fbdefadeafba784 Mon Sep 17 00:00:00 2001 From: Will Date: Fri, 15 Aug 2025 15:19:30 +0800 Subject: [PATCH] chore: improved type annotations in MCP-related codes (#23984) --- api/core/mcp/client/sse_client.py | 7 +++-- api/core/mcp/server/streamable_http.py | 7 +++-- api/core/mcp/utils.py | 42 +++++++++++++++++++++----- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 2d3a3f534..c6fe768a6 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final from urllib.parse import urljoin, urlparse import httpx +from httpx_sse import EventSource, ServerSentEvent from sseclient import SSEClient from core.mcp import types @@ -114,7 +115,7 @@ class SSETransport: logger.exception("Error parsing server message") read_queue.put(exc) - def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Handle a single SSE event. Args: @@ -130,7 +131,7 @@ class SSETransport: case _: logger.warning("Unknown SSE event: %s", sse.event) - def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Read and process SSE events. Args: @@ -225,7 +226,7 @@ class SSETransport: self, executor: ThreadPoolExecutor, client: httpx.Client, - event_source, + event_source: EventSource, ) -> tuple[ReadQueue, WriteQueue]: """Establish connection and start worker threads. diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 496b5432a..efe91bbff 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -16,13 +16,14 @@ from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService -""" -Apply to MCP HTTP streamable server with stateless http -""" logger = logging.getLogger(__name__) class MCPServerStreamableHTTPRequestHandler: + """ + Apply to MCP HTTP streamable server with stateless http + """ + def __init__( self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] ): diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index a54badcd4..80912bc4c 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -1,6 +1,10 @@ import json +from collections.abc import Generator +from contextlib import AbstractContextManager import httpx +import httpx_sse +from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError @@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client( ) -def ssrf_proxy_sse_connect(url, **kwargs): +def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]: """Connect to SSE endpoint with SSRF proxy protection. This function creates an SSE connection using the configured proxy settings - to prevent SSRF attacks when connecting to external endpoints. + to prevent SSRF attacks when connecting to external endpoints. It returns + a context manager that yields an EventSource object for SSE streaming. + + The function handles HTTP client creation and cleanup automatically, but + also accepts a pre-configured client via kwargs. Args: - url: The SSE endpoint URL - **kwargs: Additional arguments passed to the SSE connection + url (str): The SSE endpoint URL to connect to + **kwargs: Additional arguments passed to the SSE connection, including: + - client (httpx.Client, optional): Pre-configured HTTP client. + If not provided, one will be created with SSRF protection. + - method (str, optional): HTTP method to use, defaults to "GET" + - headers (dict, optional): HTTP headers to include in the request + - timeout (httpx.Timeout, optional): Timeout configuration for the connection Returns: - EventSource object for SSE streaming + AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource + object for SSE streaming. The EventSource provides access to server-sent events. + + Example: + ```python + with ssrf_proxy_sse_connect(url, headers=headers) as event_source: + for sse in event_source.iter_sse(): + print(sse.event, sse.data) + ``` + + Note: + If a client is not provided in kwargs, one will be automatically created + with SSRF protection based on the application's configuration. If an + exception occurs during connection, any automatically created client + will be cleaned up automatically. """ - from httpx_sse import connect_sse # Extract client if provided, otherwise create one client = kwargs.pop("client", None) @@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs): raise -def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None): +def create_mcp_error_response( + request_id: int | str | None, code: int, message: str, data=None +) -> Generator[bytes, None, None]: """Create MCP error response""" error_data = ErrorData(code=code, message=message, data=data) json_response = JSONRPCError(