diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 3456770a2..eb783297c 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -5,7 +5,7 @@ import os import secrets import urllib.parse from typing import Optional -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse import httpx from pydantic import BaseModel, ValidationError @@ -99,9 +99,37 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta return full_state_data +def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: + """Check if the server supports OAuth 2.0 Resource Discovery.""" + b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + if b_query: + url_for_resource_discovery += f"?{b_query}" + if b_fragment: + url_for_resource_discovery += f"#{b_fragment}" + try: + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} + response = httpx.get(url_for_resource_discovery, headers=headers) + if 200 <= response.status_code < 300: + body = response.json() + if "authorization_server_url" in body: + return True, body["authorization_server_url"][0] + else: + return False, "" + return False, "" + except httpx.RequestError as e: + # Not support resource discovery, fall back to well-known OAuth metadata + return False, "" + + def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" - url = urljoin(server_url, "/.well-known/oauth-authorization-server") + # First check if the server supports OAuth 2.0 Resource Discovery + support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) + if support_resource_discovery: + url = oauth_discovery_url + else: + url = urljoin(server_url, "/.well-known/oauth-authorization-server") try: headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}