diff --git a/api/models/account.py b/api/models/account.py index 0c5bb6ff0..6db1381df 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,12 +1,12 @@ import enum import json from datetime import datetime -from typing import Optional, cast +from typing import Optional import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, mapped_column, reconstructor +from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor from models.base import Base @@ -118,10 +118,24 @@ class Account(UserMixin, Base): @current_tenant.setter def current_tenant(self, tenant: "Tenant"): - ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)) - if ta: - self.role = TenantAccountRole(ta.role) - self._current_tenant = tenant + with Session(db.engine, expire_on_commit=False) as session: + tenant_join_query = select(TenantAccountJoin).where( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id + ) + tenant_join = session.scalar(tenant_join_query) + tenant_query = select(Tenant).where(Tenant.id == tenant.id) + # TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing + # access to it after the session has been closed. + # This prevents `DetachedInstanceError` when accessing the tenant outside + # the session's lifecycle. + # (The `tenant` argument is typically loaded by `db.session` without the + # `expire_on_commit=False` flag, meaning its lifetime is tied to the web + # request's lifecycle.) + tenant_reloaded = session.scalars(tenant_query).one() + + if tenant_join: + self.role = TenantAccountRole(tenant_join.role) + self._current_tenant = tenant_reloaded return self._current_tenant = None @@ -130,23 +144,19 @@ class Account(UserMixin, Base): return self._current_tenant.id if self._current_tenant else None def set_tenant_id(self, tenant_id: str): - tenant_account_join = cast( - tuple[Tenant, TenantAccountJoin], - ( - db.session.query(Tenant, TenantAccountJoin) - .where(Tenant.id == tenant_id) - .where(TenantAccountJoin.tenant_id == Tenant.id) - .where(TenantAccountJoin.account_id == self.id) - .one_or_none() - ), + query = ( + select(Tenant, TenantAccountJoin) + .where(Tenant.id == tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.account_id == self.id) ) - - if not tenant_account_join: - return - - tenant, join = tenant_account_join - self.role = TenantAccountRole(join.role) - self._current_tenant = tenant + with Session(db.engine, expire_on_commit=False) as session: + tenant_account_join = session.execute(query).first() + if not tenant_account_join: + return + tenant, join = tenant_account_join + self.role = TenantAccountRole(join.role) + self._current_tenant = tenant @property def current_role(self):