feat: add MCP support (#20716)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
Novice
2025-07-10 14:01:34 +08:00
committed by GitHub
parent 18b58424ec
commit 535fff62f3
54 changed files with 6634 additions and 154 deletions

View File

@@ -0,0 +1,342 @@
import base64
import hashlib
import json
import os
import secrets
import urllib.parse
from typing import Optional
from urllib.parse import urljoin
import requests
from pydantic import BaseModel, ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
state_key = secrets.token_urlsafe(32)
# Store the state data in Redis with expiration
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
return state_key
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
state_data = redis_client.get(redis_key)
if not state_data:
raise ValueError("State parameter has expired or does not exist")
# Delete the state data from Redis immediately after retrieval to prevent reuse
redis_client.delete(redis_key)
try:
# Parse and validate the state data
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except ValidationError as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,
full_state_data.client_information,
authorization_code,
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = requests.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.ok:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except requests.RequestException as e:
if isinstance(e, requests.ConnectionError):
response = requests.get(url)
if response.status_code == 404:
return None
if not response.ok:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
raise
def start_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
provider_id: str,
tenant_id: str,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
code_challenge_method = "S256"
if metadata:
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else:
authorization_url = urljoin(server_url, "/authorize")
code_verifier, code_challenge = generate_pkce_challenge()
# Prepare state data with all necessary information
state_data = OAuthCallbackState(
provider_id=provider_id,
tenant_id=tenant_id,
server_url=server_url,
metadata=metadata,
client_information=client_information,
code_verifier=code_verifier,
redirect_uri=redirect_url,
)
# Store state data in Redis and generate secure state key
state_key = _create_secure_redis_state(state_data)
params = {
"response_type": response_type,
"client_id": client_information.client_id,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": state_key,
}
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def exchange_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
authorization_code: str,
code_verifier: str,
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"code": authorization_code,
"code_verifier": code_verifier,
"redirect_uri": redirect_uri,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
def refresh_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"refresh_token": refresh_token,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.parse_obj(response.json())
def register_client(
server_url: str,
metadata: Optional[OAuthMetadata],
client_metadata: OAuthClientMetadata,
) -> OAuthClientInformationFull:
"""Performs OAuth 2.0 Dynamic Client Registration."""
if metadata:
if not metadata.registration_endpoint:
raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
else:
registration_url = urljoin(server_url, "/register")
response = requests.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
if not response.ok:
response.raise_for_status()
return OAuthClientInformationFull.model_validate(response.json())
def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except requests.RequestException as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
client_information = full_information
# Exchange authorization code for tokens
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
try:
# Retrieve state data from Redis using state key
full_state_data = _retrieve_redis_state(state_param)
code_verifier = full_state_data.code_verifier
redirect_uri = full_state_data.redirect_uri
if not code_verifier or not redirect_uri:
raise ValueError("Missing code_verifier or redirect_uri in state data")
except (json.JSONDecodeError, ValueError) as e:
raise ValueError(f"Invalid state parameter: {e}")
tokens = exchange_authorization(
server_url,
metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
provider_tokens = provider.tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}