Files
dify/api/extensions/ext_login.py
2025-05-26 18:22:01 +08:00

101 lines
3.8 KiB
Python

import json
import flask_login # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized
import contexts
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from models.account import Account, Tenant, TenantAccountJoin
from models.model import EndUser
from services.account_service import AccountService
# Module-level Flask-Login manager for this extension. The request loader and
# unauthorized handler defined below register themselves on it, and init_app()
# binds it to the application.
login_manager = flask_login.LoginManager()
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
    """Load the current user from the incoming request.

    The credential is read from the ``Authorization: Bearer <token>`` header
    or, when that header is absent, from the ``_token`` query parameter.
    Resolution order:

    1. If admin API key authentication is enabled and the bearer token equals
       the configured key, return the owner account of the workspace named by
       the ``X-WORKSPACE-ID`` header (when one can be resolved). Otherwise fall
       through to the blueprint-specific handling below.
    2. On the ``console`` / ``inner_api`` blueprints, verify the token with
       ``PassportService`` and load the account identified by ``user_id``.
    3. On the ``web`` blueprint, verify the token and load the ``EndUser``
       identified by ``end_user_id``.

    Args:
        request_from_flask_login: Request object supplied by Flask-Login.
            Unused; the Flask ``request`` proxy is read instead.

    Returns:
        The resolved account or end user, or ``None`` when the request
        targets any other blueprint.

    Raises:
        Unauthorized: If the Authorization header is malformed, the token is
            missing, or the decoded token lacks the expected identifier.
        NotFound: If the end user referenced by the token does not exist.
    """
    import hmac  # function-scope import so this block does not rely on a new file-level import

    invalid_header_message = "Invalid Authorization header format. Expected 'Bearer <api-key>' format."
    invalid_token_message = "Invalid Authorization token."

    auth_header = request.headers.get("Authorization", "")
    auth_token: str | None = None
    if auth_header:
        if " " not in auth_header:
            raise Unauthorized(invalid_header_message)
        # A value such as "Bearer " contains a space but splits into a single
        # part; reject it explicitly instead of letting the tuple unpacking
        # fail with a ValueError (which would surface as a server error).
        header_parts = auth_header.split(maxsplit=1)
        if len(header_parts) != 2:
            raise Unauthorized(invalid_header_message)
        auth_scheme, auth_token = header_parts
        if auth_scheme.lower() != "bearer":
            raise Unauthorized(invalid_header_message)
    else:
        auth_token = request.args.get("_token")

    # Check for admin API key authentication first.
    if dify_config.ADMIN_API_KEY_ENABLE and auth_header:
        admin_api_key = dify_config.ADMIN_API_KEY
        # Constant-time comparison: the token is attacker-controlled, so avoid
        # leaking how many leading characters of the key matched.
        # NOTE(review): assumes ADMIN_API_KEY is a str when set -- confirm.
        if (
            admin_api_key
            and auth_token
            and hmac.compare_digest(admin_api_key.encode("utf-8"), auth_token.encode("utf-8"))
        ):
            workspace_id = request.headers.get("X-WORKSPACE-ID")
            if workspace_id:
                account = _load_workspace_owner_account(workspace_id)
                if account:
                    return account

    if request.blueprint in {"console", "inner_api"}:
        if not auth_token:
            raise Unauthorized(invalid_token_message)
        decoded = PassportService().verify(auth_token)
        user_id = decoded.get("user_id")
        if not user_id:
            raise Unauthorized(invalid_token_message)
        return AccountService.load_logged_in_account(account_id=user_id)

    if request.blueprint == "web":
        # Same guard as the console branch: never hand a missing token to the verifier.
        if not auth_token:
            raise Unauthorized(invalid_token_message)
        decoded = PassportService().verify(auth_token)
        end_user_id = decoded.get("end_user_id")
        if not end_user_id:
            raise Unauthorized(invalid_token_message)
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
        if not end_user:
            raise NotFound("End user not found.")
        return end_user

    # Any other blueprint: no user is loaded from the request.
    return None


def _load_workspace_owner_account(workspace_id):
    """Return the owner account of ``workspace_id`` with that tenant attached, or ``None``."""
    tenant_account_join = (
        db.session.query(Tenant, TenantAccountJoin)
        .filter(Tenant.id == workspace_id)
        .filter(TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.role == "owner")
        .one_or_none()
    )
    if not tenant_account_join:
        return None
    tenant, ta = tenant_account_join
    account = db.session.query(Account).filter_by(id=ta.account_id).first()
    if account:
        # Attach the workspace so downstream code sees it as the current tenant.
        account.current_tenant = tenant
    return account
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
    """Record the tenant of a freshly loaded account in the request context.

    Connected to both the ``user_logged_in`` and ``user_loaded_from_request``
    signals. Only ``Account`` users with a ``current_tenant_id`` are handled;
    anything else is ignored.

    Note: AccountService.load_logged_in_account will populate
    user.current_tenant_id through the load_user method, which calls
    account.set_tenant_id().
    """
    if not user or not isinstance(user, Account):
        return
    tenant_id = user.current_tenant_id
    if tenant_id:
        contexts.tenant_id.set(tenant_id)
@login_manager.unauthorized_handler
def unauthorized_handler():
    """Build the JSON 401 response returned for unauthorized requests."""
    payload = {"code": "unauthorized", "message": "Unauthorized."}
    body = json.dumps(payload)
    return Response(body, status=401, content_type="application/json")
def init_app(app: DifyApp):
    """Bind the module-level ``login_manager`` to the given application."""
    login_manager.init_app(app)