feat: oauth provider (#24206)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
This commit is contained in:
Junyan Qin (Chin)
2025-08-29 14:10:51 +08:00
committed by GitHub
parent 3d5a4df9d0
commit f32e176d6a
32 changed files with 757 additions and 22 deletions

View File

@@ -0,0 +1,94 @@
import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.model import OAuthProviderApp
from services.account_service import AccountService
class OAuthGrantType(enum.StrEnum):
AUTHORIZATION_CODE = "authorization_code"
REFRESH_TOKEN = "refresh_token"
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
class OAuthServerService:
@staticmethod
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
return session.execute(query).scalar_one_or_none()
@staticmethod
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
code = str(uuid.uuid4())
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
return code
@staticmethod
def sign_oauth_access_token(
grant_type: OAuthGrantType,
code: str = "",
client_id: str = "",
refresh_token: str = "",
) -> tuple[str, str]:
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid code")
# delete code
redis_client.delete(redis_key)
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
return access_token, refresh_token
case OAuthGrantType.REFRESH_TOKEN:
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
raise BadRequest("invalid refresh token")
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
return access_token, refresh_token
@staticmethod
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
return token
@staticmethod
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
token = str(uuid.uuid4())
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
return token
@staticmethod
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
user_account_id = redis_client.get(redis_key)
if not user_account_id:
return None
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)