refactor: Use typed SQLAlchemy base model and fix type errors (#19980)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -3,11 +3,14 @@ import json
|
||||
import flask_login # type: ignore
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
import contexts
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
login_manager = flask_login.LoginManager()
|
||||
@@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
auth_token: str | None = None
|
||||
if auth_header:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme, auth_token = auth_header.split(maxsplit=1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
else:
|
||||
auth_token = request.args.get("_token")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
if request.blueprint in {"console", "inner_api"}:
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "web":
|
||||
decoded = PassportService().verify(auth_token)
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if not end_user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
|
||||
|
||||
@user_logged_in.connect
|
||||
@user_loaded_from_request.connect
|
||||
def on_user_logged_in(_sender, user):
|
||||
"""Called when a user logged in."""
|
||||
if user:
|
||||
"""Called when a user logged in.
|
||||
|
||||
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||
through the load_user method, which calls account.set_tenant_id().
|
||||
"""
|
||||
if user and isinstance(user, Account) and user.current_tenant_id:
|
||||
contexts.tenant_id.set(user.current_tenant_id)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user