From 2a29c61041d1fd8e01106a5ae5bcbae41696d993 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Wed, 27 Aug 2025 17:46:52 +0800 Subject: [PATCH] Refactor: replace count() > 0 check with exists() (#24583) Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/message.py | 14 +++++----- api/models/model.py | 4 +-- api/models/workflow.py | 15 ++++++----- api/services/dataset_service.py | 8 +++--- .../tools/builtin_tools_manage_service.py | 27 ++++++++++++------- api/services/workflow_service.py | 11 ++++---- .../disable_annotation_reply_task.py | 5 ++-- 7 files changed, 44 insertions(+), 40 deletions(-) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f61ddb464..05b668b80 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,6 +3,7 @@ import logging from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range +from sqlalchemy import exists, select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api @@ -94,21 +95,18 @@ class ChatMessageListApi(Resource): .all() ) - has_more = False if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = ( - db.session.query(Message) - .where( + + has_more = db.session.scalar( + select( + exists().where( Message.conversation_id == conversation.id, Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id, ) - .count() ) - - if rest_count > 0: - has_more = True + ) history_messages = list(reversed(history_messages)) diff --git a/api/models/model.py b/api/models/model.py index aeb2cad62..ed1be14a6 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -1553,7 +1553,7 @@ class ApiToken(Base): def generate_api_key(prefix, n): while True: result = prefix + generate_string(n) - if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: + if db.session.scalar(select(exists().where(ApiToken.token == result))): continue return result diff --git a/api/models/workflow.py b/api/models/workflow.py index 2c1b86738..4d0089fa4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, orm +from sqlalchemy import DateTime, exists, orm, select from core.file.constants import maybe_file_object from core.file.models import File @@ -336,12 +336,13 @@ class Workflow(Base): """ from models.tools import WorkflowToolProvider - return ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) - .count() - > 0 + stmt = select( + exists().where( + WorkflowToolProvider.tenant_id == self.tenant_id, + WorkflowToolProvider.app_id == self.app_id, + ) ) + return db.session.execute(stmt).scalar_one() @property def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: @@ -921,7 +922,7 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): """`WorkflowDraftVariable` record variables and outputs generated during - debugging worfklow or chatflow. + debugging workflow or chatflow. IMPORTANT: This model maintains multiple invariant rules that must be preserved. Do not instantiate this class directly with the constructor. diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 19119271e..84860fd17 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,7 +9,7 @@ from collections import Counter from typing import Any, Literal, Optional from flask_login import current_user -from sqlalchemy import func, select +from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -655,10 +655,8 @@ class DatasetService: @staticmethod def dataset_use_check(dataset_id) -> bool: - count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() - if count > 0: - return True - return False + stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id)) + return db.session.execute(stmt).scalar_one() @staticmethod def check_dataset_permission(dataset, user): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 84b958023..71bc50017 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from pathlib import Path from typing import Any, Optional +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -190,11 +191,14 @@ class BuiltinToolManageService: # update name if provided if name and name != db_provider.name: # check if the name is already used - if ( - session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider, name=name) - .count() - > 0 + if session.scalar( + select( + exists().where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.name == name, + ) + ) ): raise ValueError(f"the credential name '{name}' is already used") @@ -246,11 +250,14 @@ class BuiltinToolManageService: ) else: # check if the name is already used - if ( - session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider, name=name) - .count() - > 0 + if session.scalar( + select( + exists().where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.name == name, + ) + ) ): raise ValueError(f"the credential name '{name}' is already used") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d2715a61f..3a6837978 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, Optional, cast from uuid import uuid4 -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType @@ -87,15 +87,14 @@ class WorkflowService: ) def is_workflow_exist(self, app_model: App) -> bool: - return ( - db.session.query(Workflow) - .where( + stmt = select( + exists().where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, ) - .count() - ) > 0 + ) + return db.session.execute(stmt).scalar_one() def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c824059bf..c0020b29e 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import exists, select from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): start_at = time.perf_counter() # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() + annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) db.session.close() @@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): ) try: - if annotations_count > 0: + if annotations_exists: vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.delete() except Exception: