feat(oauth): plugin oauth service (#21480)
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import binascii
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient):
|
|||||||
provider: str,
|
provider: str,
|
||||||
system_credentials: Mapping[str, Any],
|
system_credentials: Mapping[str, Any],
|
||||||
) -> PluginOAuthAuthorizationUrlResponse:
|
) -> PluginOAuthAuthorizationUrlResponse:
|
||||||
return self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
|
||||||
PluginOAuthAuthorizationUrlResponse,
|
PluginOAuthAuthorizationUrlResponse,
|
||||||
@@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient):
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||||
|
|
||||||
def get_credentials(
|
def get_credentials(
|
||||||
self,
|
self,
|
||||||
@@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient):
|
|||||||
# encode request to raw http request
|
# encode request to raw http request
|
||||||
raw_request_bytes = self._convert_request_to_raw_data(request)
|
raw_request_bytes = self._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
return self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
|
||||||
PluginOAuthCredentialsResponse,
|
PluginOAuthCredentialsResponse,
|
||||||
@@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient):
|
|||||||
"data": {
|
"data": {
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
"system_credentials": system_credentials,
|
"system_credentials": system_credentials,
|
||||||
"raw_request_bytes": raw_request_bytes,
|
# for json serialization
|
||||||
|
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
@@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient):
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
raise ValueError("No response received from plugin daemon for authorization URL request.")
|
||||||
|
|
||||||
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
def _convert_request_to_raw_data(self, request: Request) -> bytes:
|
||||||
"""
|
"""
|
||||||
@@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient):
|
|||||||
"""
|
"""
|
||||||
# Start with the request line
|
# Start with the request line
|
||||||
method = request.method
|
method = request.method
|
||||||
path = request.path
|
path = request.full_path
|
||||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||||
|
|
||||||
|
@@ -1,7 +1,61 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
class OAuthService(BasePluginClient):
|
class OAuthProxyService(BasePluginClient):
|
||||||
@classmethod
|
# Default max age for proxy context parameter in seconds
|
||||||
def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
|
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||||
return "1234567890"
|
|
||||||
|
@staticmethod
|
||||||
|
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
|
||||||
|
"""
|
||||||
|
Create a proxy context for an OAuth 2.0 authorization request.
|
||||||
|
|
||||||
|
This parameter is a crucial security measure to prevent Cross-Site Request
|
||||||
|
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
|
||||||
|
in a distributed cache (Redis) along with the user's session context.
|
||||||
|
|
||||||
|
The returned nonce should be included as the 'proxy_context' parameter in the
|
||||||
|
authorization URL. Upon callback, the `use_proxy_context` method
|
||||||
|
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||||
|
and mitigating replay attacks.
|
||||||
|
"""
|
||||||
|
seconds, _ = redis_client.time()
|
||||||
|
context_id = str(uuid.uuid4())
|
||||||
|
data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"provider": provider,
|
||||||
|
# encode redis time to avoid distribution time skew
|
||||||
|
"timestamp": seconds,
|
||||||
|
}
|
||||||
|
# ignore nonce collision
|
||||||
|
redis_client.setex(
|
||||||
|
f"oauth_proxy_context:{context_id}",
|
||||||
|
OAuthProxyService.__MAX_AGE__,
|
||||||
|
json.dumps(data),
|
||||||
|
)
|
||||||
|
return context_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def use_proxy_context(context_id, max_age=__MAX_AGE__):
|
||||||
|
"""
|
||||||
|
Validate the proxy context parameter.
|
||||||
|
This checks if the context_id is valid and not expired.
|
||||||
|
"""
|
||||||
|
if not context_id:
|
||||||
|
raise ValueError("context_id is required")
|
||||||
|
# get data from redis
|
||||||
|
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
|
||||||
|
if not data:
|
||||||
|
raise ValueError("context_id is invalid")
|
||||||
|
# check if data is expired
|
||||||
|
seconds, _ = redis_client.time()
|
||||||
|
state = json.loads(data)
|
||||||
|
if state.get("timestamp") < seconds - max_age:
|
||||||
|
raise ValueError("context_id is expired")
|
||||||
|
return state
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
from werkzeug import Request
|
from werkzeug import Request
|
||||||
from werkzeug.datastructures import Headers
|
from werkzeug.datastructures import Headers
|
||||||
from werkzeug.test import EnvironBuilder
|
from werkzeug.test import EnvironBuilder
|
||||||
@@ -15,6 +17,59 @@ def test_oauth_convert_request_to_raw_data():
|
|||||||
request = Request(builder.get_environ())
|
request = Request(builder.get_environ())
|
||||||
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
assert b"GET /test HTTP/1.1" in raw_request_bytes
|
assert b"GET /test? HTTP/1.1" in raw_request_bytes
|
||||||
assert b"Content-Type: application/json" in raw_request_bytes
|
assert b"Content-Type: application/json" in raw_request_bytes
|
||||||
assert b"\r\n\r\n" in raw_request_bytes
|
assert b"\r\n\r\n" in raw_request_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_convert_request_to_raw_data_with_query_params():
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
builder = EnvironBuilder(
|
||||||
|
method="GET",
|
||||||
|
path="/test",
|
||||||
|
query_string="code=abc123&state=xyz789",
|
||||||
|
headers=Headers({"Content-Type": "application/json"}),
|
||||||
|
)
|
||||||
|
request = Request(builder.get_environ())
|
||||||
|
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
|
assert b"GET /test?code=abc123&state=xyz789 HTTP/1.1" in raw_request_bytes
|
||||||
|
assert b"Content-Type: application/json" in raw_request_bytes
|
||||||
|
assert b"\r\n\r\n" in raw_request_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_convert_request_to_raw_data_with_post_body():
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
builder = EnvironBuilder(
|
||||||
|
method="POST",
|
||||||
|
path="/test",
|
||||||
|
data="param1=value1¶m2=value2",
|
||||||
|
headers=Headers({"Content-Type": "application/x-www-form-urlencoded"}),
|
||||||
|
)
|
||||||
|
request = Request(builder.get_environ())
|
||||||
|
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
|
assert b"POST /test? HTTP/1.1" in raw_request_bytes
|
||||||
|
assert b"Content-Type: application/x-www-form-urlencoded" in raw_request_bytes
|
||||||
|
assert b"\r\n\r\n" in raw_request_bytes
|
||||||
|
assert b"param1=value1¶m2=value2" in raw_request_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_convert_request_to_raw_data_with_json_body():
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
json_data = {"code": "abc123", "state": "xyz789", "grant_type": "authorization_code"}
|
||||||
|
builder = EnvironBuilder(
|
||||||
|
method="POST",
|
||||||
|
path="/test",
|
||||||
|
data=json.dumps(json_data),
|
||||||
|
headers=Headers({"Content-Type": "application/json"}),
|
||||||
|
)
|
||||||
|
request = Request(builder.get_environ())
|
||||||
|
raw_request_bytes = oauth_handler._convert_request_to_raw_data(request)
|
||||||
|
|
||||||
|
assert b"POST /test? HTTP/1.1" in raw_request_bytes
|
||||||
|
assert b"Content-Type: application/json" in raw_request_bytes
|
||||||
|
assert b"\r\n\r\n" in raw_request_bytes
|
||||||
|
assert b'"code": "abc123"' in raw_request_bytes
|
||||||
|
assert b'"state": "xyz789"' in raw_request_bytes
|
||||||
|
assert b'"grant_type": "authorization_code"' in raw_request_bytes
|
||||||
|
Reference in New Issue
Block a user