diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 4077ec38d..b84dd0afc 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -8,9 +8,10 @@ from extensions.ext_redis import redis_client class OAuthProxyService(BasePluginClient): # Default max age for proxy context parameter in seconds __MAX_AGE__ = 5 * 60 # 5 minutes + __KEY_PREFIX__ = "oauth_proxy_context:" @staticmethod - def create_proxy_context(user_id, tenant_id, plugin_id, provider): + def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str): """ Create a proxy context for an OAuth 2.0 authorization request. @@ -23,26 +24,22 @@ class OAuthProxyService(BasePluginClient): is used to verify the state, ensuring the request's integrity and authenticity, and mitigating replay attacks. """ - seconds, _ = redis_client.time() context_id = str(uuid.uuid4()) data = { "user_id": user_id, "plugin_id": plugin_id, "tenant_id": tenant_id, "provider": provider, - # encode redis time to avoid distribution time skew - "timestamp": seconds, } - # ignore nonce collision redis_client.setex( - f"oauth_proxy_context:{context_id}", + f"{OAuthProxyService.__KEY_PREFIX__}{context_id}", OAuthProxyService.__MAX_AGE__, json.dumps(data), ) return context_id @staticmethod - def use_proxy_context(context_id, max_age=__MAX_AGE__): + def use_proxy_context(context_id: str): """ Validate the proxy context parameter. This checks if the context_id is valid and not expired. @@ -50,12 +47,7 @@ class OAuthProxyService(BasePluginClient): if not context_id: raise ValueError("context_id is required") # get data from redis - data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") if not data: raise ValueError("context_id is invalid") - # check if data is expired - seconds, _ = redis_client.time() - state = json.loads(data) - if state.get("timestamp") < seconds - max_age: - raise ValueError("context_id is expired") - return state + return json.loads(data)