orm filter -> where (#22801)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Asuka Minato
2025-07-24 01:57:45 +09:00
committed by GitHub
parent e64e7563f6
commit ef51678c73
161 changed files with 828 additions and 857 deletions

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@@ -42,17 +43,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if end_user:
@@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
@@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
end_user = db.session.scalar(
select(EndUser).where(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
@@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if not end_user:
@@ -224,6 +223,8 @@ def generate_session_id():
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)
if existing_count == 0:
return session_id

View File

@@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()

View File

@@ -3,6 +3,7 @@ from functools import wraps
from flask import request
from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
@@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()