diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c..b006bf1d4 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def get_credentials( self, @@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419..4077ec38d 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,61 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `use_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, _ = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id + + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, _ = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state diff --git a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py index f788a9756..293ac253f 100644 --- a/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py +++ b/api/tests/unit_tests/utils/http_parser/test_oauth_convert_request_to_raw_data.py @@ -1,3 +1,5 @@ +import json + from werkzeug import Request from werkzeug.datastructures import Headers from werkzeug.test import EnvironBuilder @@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data(): request = Request(builder.get_environ()) raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) - assert b"GET /test HTTP/1.1" in raw_request_bytes + assert b"GET /test? HTTP/1.1" in raw_request_bytes assert b"Content-Type: application/json" in raw_request_bytes assert b"\r\n\r\n" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_query_params(): + oauth_handler = OAuthHandler() + builder = EnvironBuilder( + method="GET", + path="/test", + query_string="code=abc123&state=xyz789", + headers=Headers({"Content-Type": "application/json"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_post_body(): + oauth_handler = OAuthHandler() + builder = EnvironBuilder( + method="POST", + path="/test", + data="param1=value1¶m2=value2", + headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"POST /test? HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b"param1=value1¶m2=value2" in raw_request_bytes + + +def test_oauth_convert_request_to_raw_data_with_json_body(): + oauth_handler = OAuthHandler() + json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"} + builder = EnvironBuilder( + method="POST", + path="/test", + data=json.dumps(json_data), + headers=Headers({"Content-Type": "application/json"}), + ) + request = Request(builder.get_environ()) + raw_request_bytes = oauth_handler._convert_request_to_raw_data(request) + + assert b"POST /test? HTTP/1.1" in raw_request_bytes + assert b"Content-Type: application/json" in raw_request_bytes + assert b"\r\n\r\n" in raw_request_bytes + assert b'"code": "abc123"' in raw_request_bytes + assert b'"state": "xyz789"' in raw_request_bytes + assert b'"grant_type": "authorization_code"' in raw_request_bytes