diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index cc4b5f65b..67d48319d 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] if not user_id: user_id = "DEFAULT-USER" - end_user = ( - db.session.query(EndUser) - .where( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == "service_api", + with Session(db.engine, expire_on_commit=False) as session: + end_user = ( + session.query(EndUser) + .where( + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() ) - .first() - ) - if end_user is None: - end_user = EndUser( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - type="service_api", - is_anonymous=user_id == "DEFAULT-USER", - session_id=user_id, - ) - db.session.add(end_user) - db.session.commit() + if end_user is None: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="service_api", + is_anonymous=user_id == "DEFAULT-USER", + session_id=user_id, + ) + session.add(end_user) + session.commit() return end_user diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 94fa5d562..1fc8916ca 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,6 +4,7 @@ from functools import wraps from flask import request from flask_restx import Resource from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError @@ -49,18 +50,19 @@ def decode_jwt_token(): decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") - app_model = db.session.scalar(select(App).where(App.id == app_id)) - site = db.session.scalar(select(Site).where(Site.code == app_code)) - if not app_model: - raise NotFound() - if not app_code or not site: - raise BadRequest("Site URL is no longer valid.") - if app_model.enable_site is False: - raise BadRequest("Site is disabled.") - end_user_id = decoded.get("end_user_id") - end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) - if not end_user: - raise NotFound() + with Session(db.engine, expire_on_commit=False) as session: + app_model = session.scalar(select(App).where(App.id == app_id)) + site = session.scalar(select(Site).where(Site.code == app_code)) + if not app_model: + raise NotFound() + if not app_code or not site: + raise BadRequest("Site URL is no longer valid.") + if app_model.enable_site is False: + raise BadRequest("Site is disabled.") + end_user_id = decoded.get("end_user_id") + end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id)) + if not end_user: + raise NotFound() # for enterprise webapp auth app_web_auth_enabled = False