Refactor: replace count() > 0 check with exists() (#24583)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yongtao Huang
2025-08-27 17:46:52 +08:00
committed by GitHub
parent 34b041e9f0
commit 2a29c61041
7 changed files with 44 additions and 40 deletions

View File

@@ -3,6 +3,7 @@ import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
@@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
.all()
)
has_more = False
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.where(
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
if rest_count > 0:
has_more = True
)
history_messages = list(reversed(history_messages))

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@@ -1553,7 +1553,7 @@ class ApiToken(Base):
def generate_api_key(prefix, n):
while True:
result = prefix + generate_string(n)
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
if db.session.scalar(select(exists().where(ApiToken.token == result))):
continue
return result

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, orm
from sqlalchemy import DateTime, exists, orm, select
from core.file.constants import maybe_file_object
from core.file.models import File
@@ -336,12 +336,13 @@ class Workflow(Base):
"""
from models.tools import WorkflowToolProvider
return (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count()
> 0
stmt = select(
exists().where(
WorkflowToolProvider.tenant_id == self.tenant_id,
WorkflowToolProvider.app_id == self.app_id,
)
)
return db.session.execute(stmt).scalar_one()
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
@@ -921,7 +922,7 @@ def _naive_utc_datetime():
class WorkflowDraftVariable(Base):
"""`WorkflowDraftVariable` record variables and outputs generated during
debugging worfklow or chatflow.
debugging workflow or chatflow.
IMPORTANT: This model maintains multiple invariant rules that must be preserved.
Do not instantiate this class directly with the constructor.

View File

@@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Literal, Optional
from flask_login import current_user
from sqlalchemy import func, select
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -655,10 +655,8 @@ class DatasetService:
@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
return db.session.execute(stmt).scalar_one()
@staticmethod
def check_dataset_permission(dataset, user):

View File

@@ -5,6 +5,7 @@ from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -190,11 +191,14 @@ class BuiltinToolManageService:
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")
@@ -246,11 +250,14 @@ class BuiltinToolManageService:
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
@@ -87,15 +87,14 @@ class WorkflowService:
)
def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
.where(
stmt = select(
exists().where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.count()
) > 0
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
"""

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import exists, select
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
@@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
)
try:
if annotations_count > 0:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception: