Refactor: replace count() > 0 check with exists() (#24583)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yongtao Huang
2025-08-27 17:46:52 +08:00
committed by GitHub
parent 34b041e9f0
commit 2a29c61041
7 changed files with 44 additions and 40 deletions

View File

@@ -3,6 +3,7 @@ import logging
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api from controllers.console import api
@@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
.all() .all()
) )
has_more = False
if len(history_messages) == args["limit"]: if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message) has_more = db.session.scalar(
.where( select(
exists().where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at, Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id, Message.id != current_page_first_message.id,
) )
.count()
) )
)
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages)) history_messages = list(reversed(history_messages))

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config from configs import dify_config
@@ -1553,7 +1553,7 @@ class ApiToken(Base):
def generate_api_key(prefix, n): def generate_api_key(prefix, n):
while True: while True:
result = prefix + generate_string(n) result = prefix + generate_string(n)
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: if db.session.scalar(select(exists().where(ApiToken.token == result))):
continue continue
return result return result

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import DateTime, orm from sqlalchemy import DateTime, exists, orm, select
from core.file.constants import maybe_file_object from core.file.constants import maybe_file_object
from core.file.models import File from core.file.models import File
@@ -336,12 +336,13 @@ class Workflow(Base):
""" """
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
return ( stmt = select(
db.session.query(WorkflowToolProvider) exists().where(
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) WorkflowToolProvider.tenant_id == self.tenant_id,
.count() WorkflowToolProvider.app_id == self.app_id,
> 0 )
) )
return db.session.execute(stmt).scalar_one()
@property @property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
@@ -921,7 +922,7 @@ def _naive_utc_datetime():
class WorkflowDraftVariable(Base): class WorkflowDraftVariable(Base):
"""`WorkflowDraftVariable` record variables and outputs generated during """`WorkflowDraftVariable` record variables and outputs generated during
debugging worfklow or chatflow. debugging workflow or chatflow.
IMPORTANT: This model maintains multiple invariant rules that must be preserved. IMPORTANT: This model maintains multiple invariant rules that must be preserved.
Do not instantiate this class directly with the constructor. Do not instantiate this class directly with the constructor.

View File

@@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from flask_login import current_user from flask_login import current_user
from sqlalchemy import func, select from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@@ -655,10 +655,8 @@ class DatasetService:
@staticmethod @staticmethod
def dataset_use_check(dataset_id) -> bool: def dataset_use_check(dataset_id) -> bool:
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
if count > 0: return db.session.execute(stmt).scalar_one()
return True
return False
@staticmethod @staticmethod
def check_dataset_permission(dataset, user): def check_dataset_permission(dataset, user):

View File

@@ -5,6 +5,7 @@ from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy import exists, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
@@ -190,11 +191,14 @@ class BuiltinToolManageService:
# update name if provided # update name if provided
if name and name != db_provider.name: if name and name != db_provider.name:
# check if the name is already used # check if the name is already used
if ( if session.scalar(
session.query(BuiltinToolProvider) select(
.filter_by(tenant_id=tenant_id, provider=provider, name=name) exists().where(
.count() BuiltinToolProvider.tenant_id == tenant_id,
> 0 BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
): ):
raise ValueError(f"the credential name '{name}' is already used") raise ValueError(f"the credential name '{name}' is already used")
@@ -246,11 +250,14 @@ class BuiltinToolManageService:
) )
else: else:
# check if the name is already used # check if the name is already used
if ( if session.scalar(
session.query(BuiltinToolProvider) select(
.filter_by(tenant_id=tenant_id, provider=provider, name=name) exists().where(
.count() BuiltinToolProvider.tenant_id == tenant_id,
> 0 BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
): ):
raise ValueError(f"the credential name '{name}' is already used") raise ValueError(f"the credential name '{name}' is already used")

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import select from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
@@ -87,15 +87,14 @@ class WorkflowService:
) )
def is_workflow_exist(self, app_model: App) -> bool: def is_workflow_exist(self, app_model: App) -> bool:
return ( stmt = select(
db.session.query(Workflow) exists().where(
.where(
Workflow.tenant_id == app_model.tenant_id, Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id, Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT, Workflow.version == Workflow.VERSION_DRAFT,
) )
.count() )
) > 0 return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
""" """

View File

@@ -3,6 +3,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from sqlalchemy import exists, select
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
@@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
start_at = time.perf_counter() start_at = time.perf_counter()
# get app info # get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app: if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red")) logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close() db.session.close()
@@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
) )
try: try:
if annotations_count > 0: if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete() vector.delete()
except Exception: except Exception: