Feat/api jwt (#1212)

This commit is contained in:
zxhlyh
2023-09-25 12:49:16 +08:00
committed by GitHub
parent c40ee7e629
commit 227f9fb77d
15 changed files with 142 additions and 339 deletions

View File

@@ -1,8 +1,7 @@
# -*- coding:utf-8 -*-
import os
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
@@ -12,12 +11,11 @@ import logging
import json
import threading
from flask import Flask, request, Response, session
import flask_login
from flask import Flask, request, Response
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
from events import event_handlers
# DO NOT REMOVE ABOVE
import core
from config import Config, CloudEditionConfig
from commands import register_commands
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
from services.account_service import TenantService
from services.account_service import AccountService
from libs.passport import PassportService
import warnings
warnings.simplefilter("ignore", ResourceWarning)
@@ -85,81 +81,33 @@ def initialize_extensions(app):
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
# Flask-Login configuration
@login_manager.user_loader
def load_user(user_id):
"""Load user based on the user_id."""
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)
return account
return AccountService.load_user(user_id)
else:
return None
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
@@ -216,6 +164,7 @@ if app.config['TESTING']:
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response