chore: improved type annotations in MCP-related codes (#23984)
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final
|
|||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from httpx_sse import EventSource, ServerSentEvent
|
||||||
from sseclient import SSEClient
|
from sseclient import SSEClient
|
||||||
|
|
||||||
from core.mcp import types
|
from core.mcp import types
|
||||||
@@ -114,7 +115,7 @@ class SSETransport:
|
|||||||
logger.exception("Error parsing server message")
|
logger.exception("Error parsing server message")
|
||||||
read_queue.put(exc)
|
read_queue.put(exc)
|
||||||
|
|
||||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||||
"""Handle a single SSE event.
|
"""Handle a single SSE event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -130,7 +131,7 @@ class SSETransport:
|
|||||||
case _:
|
case _:
|
||||||
logger.warning("Unknown SSE event: %s", sse.event)
|
logger.warning("Unknown SSE event: %s", sse.event)
|
||||||
|
|
||||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||||
"""Read and process SSE events.
|
"""Read and process SSE events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -225,7 +226,7 @@ class SSETransport:
|
|||||||
self,
|
self,
|
||||||
executor: ThreadPoolExecutor,
|
executor: ThreadPoolExecutor,
|
||||||
client: httpx.Client,
|
client: httpx.Client,
|
||||||
event_source,
|
event_source: EventSource,
|
||||||
) -> tuple[ReadQueue, WriteQueue]:
|
) -> tuple[ReadQueue, WriteQueue]:
|
||||||
"""Establish connection and start worker threads.
|
"""Establish connection and start worker threads.
|
||||||
|
|
||||||
|
@@ -16,13 +16,14 @@ from extensions.ext_database import db
|
|||||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
"""
|
|
||||||
Apply to MCP HTTP streamable server with stateless http
|
|
||||||
"""
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MCPServerStreamableHTTPRequestHandler:
|
class MCPServerStreamableHTTPRequestHandler:
|
||||||
|
"""
|
||||||
|
Apply to MCP HTTP streamable server with stateless http
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||||
):
|
):
|
||||||
|
@@ -1,6 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import httpx_sse
|
||||||
|
from httpx_sse import connect_sse
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.mcp.types import ErrorData, JSONRPCError
|
from core.mcp.types import ErrorData, JSONRPCError
|
||||||
@@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ssrf_proxy_sse_connect(url, **kwargs):
|
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
|
||||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||||
|
|
||||||
This function creates an SSE connection using the configured proxy settings
|
This function creates an SSE connection using the configured proxy settings
|
||||||
to prevent SSRF attacks when connecting to external endpoints.
|
to prevent SSRF attacks when connecting to external endpoints. It returns
|
||||||
|
a context manager that yields an EventSource object for SSE streaming.
|
||||||
|
|
||||||
|
The function handles HTTP client creation and cleanup automatically, but
|
||||||
|
also accepts a pre-configured client via kwargs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: The SSE endpoint URL
|
url (str): The SSE endpoint URL to connect to
|
||||||
**kwargs: Additional arguments passed to the SSE connection
|
**kwargs: Additional arguments passed to the SSE connection, including:
|
||||||
|
- client (httpx.Client, optional): Pre-configured HTTP client.
|
||||||
|
If not provided, one will be created with SSRF protection.
|
||||||
|
- method (str, optional): HTTP method to use, defaults to "GET"
|
||||||
|
- headers (dict, optional): HTTP headers to include in the request
|
||||||
|
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EventSource object for SSE streaming
|
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
|
||||||
|
object for SSE streaming. The EventSource provides access to server-sent events.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
|
||||||
|
for sse in event_source.iter_sse():
|
||||||
|
print(sse.event, sse.data)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If a client is not provided in kwargs, one will be automatically created
|
||||||
|
with SSRF protection based on the application's configuration. If an
|
||||||
|
exception occurs during connection, any automatically created client
|
||||||
|
will be cleaned up automatically.
|
||||||
"""
|
"""
|
||||||
from httpx_sse import connect_sse
|
|
||||||
|
|
||||||
# Extract client if provided, otherwise create one
|
# Extract client if provided, otherwise create one
|
||||||
client = kwargs.pop("client", None)
|
client = kwargs.pop("client", None)
|
||||||
@@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
|
def create_mcp_error_response(
|
||||||
|
request_id: int | str | None, code: int, message: str, data=None
|
||||||
|
) -> Generator[bytes, None, None]:
|
||||||
"""Create MCP error response"""
|
"""Create MCP error response"""
|
||||||
error_data = ErrorData(code=code, message=message, data=data)
|
error_data = ErrorData(code=code, message=message, data=data)
|
||||||
json_response = JSONRPCError(
|
json_response = JSONRPCError(
|
||||||
|
Reference in New Issue
Block a user