Refactor: replace count() > 0 check with exists() (#24583)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
|
from sqlalchemy import exists, select
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
has_more = False
|
|
||||||
if len(history_messages) == args["limit"]:
|
if len(history_messages) == args["limit"]:
|
||||||
current_page_first_message = history_messages[-1]
|
current_page_first_message = history_messages[-1]
|
||||||
rest_count = (
|
|
||||||
db.session.query(Message)
|
has_more = db.session.scalar(
|
||||||
.where(
|
select(
|
||||||
|
exists().where(
|
||||||
Message.conversation_id == conversation.id,
|
Message.conversation_id == conversation.id,
|
||||||
Message.created_at < current_page_first_message.created_at,
|
Message.created_at < current_page_first_message.created_at,
|
||||||
Message.id != current_page_first_message.id,
|
Message.id != current_page_first_message.id,
|
||||||
)
|
)
|
||||||
.count()
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if rest_count > 0:
|
|
||||||
has_more = True
|
|
||||||
|
|
||||||
history_messages = list(reversed(history_messages))
|
history_messages = list(reversed(history_messages))
|
||||||
|
|
||||||
|
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
|
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -1553,7 +1553,7 @@ class ApiToken(Base):
|
|||||||
def generate_api_key(prefix, n):
|
def generate_api_key(prefix, n):
|
||||||
while True:
|
while True:
|
||||||
result = prefix + generate_string(n)
|
result = prefix + generate_string(n)
|
||||||
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
|
if db.session.scalar(select(exists().where(ApiToken.token == result))):
|
||||||
continue
|
continue
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import DateTime, orm
|
from sqlalchemy import DateTime, exists, orm, select
|
||||||
|
|
||||||
from core.file.constants import maybe_file_object
|
from core.file.constants import maybe_file_object
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
@@ -336,12 +336,13 @@ class Workflow(Base):
|
|||||||
"""
|
"""
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
|
|
||||||
return (
|
stmt = select(
|
||||||
db.session.query(WorkflowToolProvider)
|
exists().where(
|
||||||
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
|
WorkflowToolProvider.tenant_id == self.tenant_id,
|
||||||
.count()
|
WorkflowToolProvider.app_id == self.app_id,
|
||||||
> 0
|
)
|
||||||
)
|
)
|
||||||
|
return db.session.execute(stmt).scalar_one()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||||
@@ -921,7 +922,7 @@ def _naive_utc_datetime():
|
|||||||
|
|
||||||
class WorkflowDraftVariable(Base):
|
class WorkflowDraftVariable(Base):
|
||||||
"""`WorkflowDraftVariable` record variables and outputs generated during
|
"""`WorkflowDraftVariable` record variables and outputs generated during
|
||||||
debugging worfklow or chatflow.
|
debugging workflow or chatflow.
|
||||||
|
|
||||||
IMPORTANT: This model maintains multiple invariant rules that must be preserved.
|
IMPORTANT: This model maintains multiple invariant rules that must be preserved.
|
||||||
Do not instantiate this class directly with the constructor.
|
Do not instantiate this class directly with the constructor.
|
||||||
|
@@ -9,7 +9,7 @@ from collections import Counter
|
|||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import exists, func, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
@@ -655,10 +655,8 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dataset_use_check(dataset_id) -> bool:
|
def dataset_use_check(dataset_id) -> bool:
|
||||||
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
||||||
if count > 0:
|
return db.session.execute(stmt).scalar_one()
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dataset_permission(dataset, user):
|
def check_dataset_permission(dataset, user):
|
||||||
|
@@ -5,6 +5,7 @@ from collections.abc import Mapping
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import exists, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -190,11 +191,14 @@ class BuiltinToolManageService:
|
|||||||
# update name if provided
|
# update name if provided
|
||||||
if name and name != db_provider.name:
|
if name and name != db_provider.name:
|
||||||
# check if the name is already used
|
# check if the name is already used
|
||||||
if (
|
if session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(
|
||||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
exists().where(
|
||||||
.count()
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
> 0
|
BuiltinToolProvider.provider == provider,
|
||||||
|
BuiltinToolProvider.name == name,
|
||||||
|
)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise ValueError(f"the credential name '{name}' is already used")
|
raise ValueError(f"the credential name '{name}' is already used")
|
||||||
|
|
||||||
@@ -246,11 +250,14 @@ class BuiltinToolManageService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# check if the name is already used
|
# check if the name is already used
|
||||||
if (
|
if session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(
|
||||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
exists().where(
|
||||||
.count()
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
> 0
|
BuiltinToolProvider.provider == provider,
|
||||||
|
BuiltinToolProvider.name == name,
|
||||||
|
)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise ValueError(f"the credential name '{name}' is already used")
|
raise ValueError(f"the credential name '{name}' is already used")
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
|||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import exists, select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
@@ -87,15 +87,14 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def is_workflow_exist(self, app_model: App) -> bool:
|
def is_workflow_exist(self, app_model: App) -> bool:
|
||||||
return (
|
stmt = select(
|
||||||
db.session.query(Workflow)
|
exists().where(
|
||||||
.where(
|
|
||||||
Workflow.tenant_id == app_model.tenant_id,
|
Workflow.tenant_id == app_model.tenant_id,
|
||||||
Workflow.app_id == app_model.id,
|
Workflow.app_id == app_model.id,
|
||||||
Workflow.version == Workflow.VERSION_DRAFT,
|
Workflow.version == Workflow.VERSION_DRAFT,
|
||||||
)
|
)
|
||||||
.count()
|
)
|
||||||
) > 0
|
return db.session.execute(stmt).scalar_one()
|
||||||
|
|
||||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||||
"""
|
"""
|
||||||
|
@@ -3,6 +3,7 @@ import time
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
from sqlalchemy import exists, select
|
||||||
|
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
# get app info
|
# get app info
|
||||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||||
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
|
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
|
||||||
if not app:
|
if not app:
|
||||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||||
db.session.close()
|
db.session.close()
|
||||||
@@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if annotations_count > 0:
|
if annotations_exists:
|
||||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||||
vector.delete()
|
vector.delete()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
Reference in New Issue
Block a user