Feat/enterprise sso (#3602)

This commit is contained in:
Garfield Dai
2024-04-18 17:33:32 +08:00
committed by GitHub
parent d9f1a8ce9f
commit 4481906be2
30 changed files with 518 additions and 21 deletions

View File

@@ -8,7 +8,7 @@ from typing import Any, Optional
from flask import current_app
from sqlalchemy import func
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
@@ -44,7 +44,7 @@ class AccountService:
return None
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
raise Forbidden('Account is banned or closed.')
raise Unauthorized("Account is banned or closed.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
@@ -255,7 +255,7 @@ class TenantService:
"""Get account join tenants"""
return db.session.query(Tenant).join(
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
).filter(TenantAccountJoin.account_id == account.id).all()
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
@staticmethod
def get_current_tenant_by_account(account: Account):
@@ -279,7 +279,12 @@ class TenantService:
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
).first()
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:

View File

View File

@@ -0,0 +1,20 @@
import os
import requests
class EnterpriseRequest:
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Enterprise-Api-Secret-Key": cls.secret_key
}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@@ -0,0 +1,28 @@
from flask import current_app
from pydantic import BaseModel
from services.enterprise.enterprise_service import EnterpriseService
class EnterpriseFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ''
class EnterpriseFeatureService:
@classmethod
def get_enterprise_features(cls) -> EnterpriseFeatureModel:
features = EnterpriseFeatureModel()
if current_app.config['ENTERPRISE_ENABLED']:
cls._fulfill_params_from_enterprise(features)
return features
@classmethod
def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info()
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']

View File

@@ -0,0 +1,8 @@
from services.enterprise.base import EnterpriseRequest
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request('GET', '/info')

View File

@@ -0,0 +1,60 @@
import logging
from models.account import Account, AccountStatus
from services.account_service import AccountService, TenantService
from services.enterprise.base import EnterpriseRequest
logger = logging.getLogger(__name__)
class EnterpriseSSOService:
@classmethod
def get_sso_saml_login(cls) -> str:
return EnterpriseRequest.send_request('GET', '/sso/saml/login')
@classmethod
def post_sso_saml_acs(cls, saml_response: str) -> str:
response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
if 'email' not in response or response['email'] is None:
logger.exception(response)
raise Exception('Saml response is invalid')
return cls.login_with_email(response.get('email'))
@classmethod
def get_sso_oidc_login(cls):
return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
@classmethod
def get_sso_oidc_callback(cls, args: dict):
state_from_query = args['state']
code_from_query = args['code']
state_from_cookies = args['oidc-state']
if state_from_cookies != state_from_query:
raise Exception('invalid state or code')
response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
if 'email' not in response or response['email'] is None:
logger.exception(response)
raise Exception('OIDC response is invalid')
return cls.login_with_email(response.get('email'))
@classmethod
def login_with_email(cls, email: str) -> str:
account = Account.query.filter_by(email=email).first()
if account is None:
raise Exception('account not found, please contact system admin to invite you to join in a workspace')
if account.status == AccountStatus.BANNED:
raise Exception('account is banned, please contact system admin')
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
token = AccountService.get_account_jwt_token(account)
return token