Refactor: replace count() > 0 check with exists() (#24583)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
)
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
|
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
@@ -1553,7 +1553,7 @@ class ApiToken(Base):
|
||||
def generate_api_key(prefix, n):
|
||||
while True:
|
||||
result = prefix + generate_string(n)
|
||||
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
|
||||
if db.session.scalar(select(exists().where(ApiToken.token == result))):
|
||||
continue
|
||||
return result
|
||||
|
||||
|
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, orm
|
||||
from sqlalchemy import DateTime, exists, orm, select
|
||||
|
||||
from core.file.constants import maybe_file_object
|
||||
from core.file.models import File
|
||||
@@ -336,12 +336,13 @@ class Workflow(Base):
|
||||
"""
|
||||
from models.tools import WorkflowToolProvider
|
||||
|
||||
return (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
|
||||
.count()
|
||||
> 0
|
||||
stmt = select(
|
||||
exists().where(
|
||||
WorkflowToolProvider.tenant_id == self.tenant_id,
|
||||
WorkflowToolProvider.app_id == self.app_id,
|
||||
)
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
@property
|
||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
@@ -921,7 +922,7 @@ def _naive_utc_datetime():
|
||||
|
||||
class WorkflowDraftVariable(Base):
|
||||
"""`WorkflowDraftVariable` record variables and outputs generated during
|
||||
debugging worfklow or chatflow.
|
||||
debugging workflow or chatflow.
|
||||
|
||||
IMPORTANT: This model maintains multiple invariant rules that must be preserved.
|
||||
Do not instantiate this class directly with the constructor.
|
||||
|
@@ -9,7 +9,7 @@ from collections import Counter
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -655,10 +655,8 @@ class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def dataset_use_check(dataset_id) -> bool:
|
||||
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_permission(dataset, user):
|
||||
|
@@ -5,6 +5,7 @@ from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@@ -190,11 +191,14 @@ class BuiltinToolManageService:
|
||||
# update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
@@ -246,11 +250,14 @@ class BuiltinToolManageService:
|
||||
)
|
||||
else:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
@@ -87,15 +87,14 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
def is_workflow_exist(self, app_model: App) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
stmt = select(
|
||||
exists().where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.count()
|
||||
) > 0
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
"""
|
||||
|
@@ -3,6 +3,7 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import exists, select
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
@@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
|
||||
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
db.session.close()
|
||||
@@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
||||
)
|
||||
|
||||
try:
|
||||
if annotations_count > 0:
|
||||
if annotations_exists:
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.delete()
|
||||
except Exception:
|
||||
|
Reference in New Issue
Block a user