Files
dify/api/core/mcp/utils.py

143 lines
5.1 KiB
Python

import json
from collections.abc import Generator
from contextlib import AbstractContextManager
import httpx
import httpx_sse
from httpx_sse import connect_sse
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
from core.model_runtime.utils.encoders import jsonable_encoder
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints. It returns
a context manager that yields an EventSource object for SSE streaming.
The function handles HTTP client creation and cleanup automatically, but
also accepts a pre-configured client via kwargs.
Args:
url (str): The SSE endpoint URL to connect to
**kwargs: Additional arguments passed to the SSE connection, including:
- client (httpx.Client, optional): Pre-configured HTTP client.
If not provided, one will be created with SSRF protection.
- method (str, optional): HTTP method to use, defaults to "GET"
- headers (dict, optional): HTTP headers to include in the request
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
Returns:
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
object for SSE streaming. The EventSource provides access to server-sent events.
Example:
```python
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
for sse in event_source.iter_sse():
print(sse.event, sse.data)
```
Note:
If a client is not provided in kwargs, one will be automatically created
with SSRF protection based on the application's configuration. If an
exception occurs during connection, any automatically created client
will be cleaned up automatically.
"""
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise
def create_mcp_error_response(
request_id: int | str | None, code: int, message: str, data=None
) -> Generator[bytes, None, None]:
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(
jsonrpc="2.0",
id=request_id or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content