orm filter -> where (#22801)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -42,17 +43,17 @@ class PassportResource(Resource):
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
|
||||
if not site:
|
||||
raise NotFound()
|
||||
# get app from db and check if it is normal and enable_site
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
if user_id:
|
||||
end_user = (
|
||||
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
|
||||
)
|
||||
|
||||
if end_user:
|
||||
@@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
||||
if not user_auth_type:
|
||||
raise Unauthorized("Missing auth_type in the token.")
|
||||
|
||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
|
||||
if not site:
|
||||
raise NotFound()
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
|
||||
@@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
||||
|
||||
end_user = None
|
||||
if end_user_id:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if session_id:
|
||||
end_user = (
|
||||
db.session.query(EndUser)
|
||||
.filter(
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(
|
||||
EndUser.session_id == session_id,
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not end_user:
|
||||
if not session_id:
|
||||
@@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
user_id = token_decoded.get("user_id")
|
||||
end_user = None
|
||||
if user_id:
|
||||
end_user = (
|
||||
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
||||
end_user = db.session.scalar(
|
||||
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
|
||||
)
|
||||
|
||||
if not end_user:
|
||||
@@ -224,6 +223,8 @@ def generate_session_id():
|
||||
"""
|
||||
while True:
|
||||
session_id = str(uuid.uuid4())
|
||||
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
|
||||
existing_count = db.session.scalar(
|
||||
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
|
||||
)
|
||||
if existing_count == 0:
|
||||
return session_id
|
||||
|
@@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
|
||||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
@@ -3,6 +3,7 @@ from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
@@ -48,8 +49,8 @@ def decode_jwt_token():
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
app_model = db.session.query(App).filter(App.id == app_id).first()
|
||||
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id))
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
raise NotFound()
|
||||
if not app_code or not site:
|
||||
@@ -57,7 +58,7 @@ def decode_jwt_token():
|
||||
if app_model.enable_site is False:
|
||||
raise BadRequest("Site is disabled.")
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
|
Reference in New Issue
Block a user