diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f7c83f927..f5e45bcb4 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() + stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id) + agent_thought = db.session.scalar(stmt) if not agent_thought: raise ValueError("agent thought not found") @@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner): return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() + stmt = select(MessageFile).where(MessageFile.message_id == message.id) + files = db.session.scalars(stmt).all() if not files: return UserPromptMessage(content=message.query) if message.app_model_config: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index b78a364a7..5e20e80d1 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -74,6 +74,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): with Session(db.engine, expire_on_commit=False) as session: app_record = session.scalar(select(App).where(App.id == app_config.app_id)) + if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 39d6ba39f..d3207365f 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from core.agent.cot_chat_agent_runner import CotChatAgentRunner from core.agent.cot_completion_agent_runner import 
CotCompletionAgentRunner from core.agent.entities import AgentEntity @@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(AgentChatAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + app_stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(app_stmt) if not app_record: raise ValueError("App not found") @@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - - conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() + conversation_stmt = select(Conversation).where(Conversation.id == conversation.id) + conversation_result = db.session.scalar(conversation_stmt) if conversation_result is None: raise ValueError("Conversation not found") - message_result = db.session.query(Message).where(Message.id == message.id).first() + msg_stmt = select(Message).where(Message.id == message.id) + message_result = db.session.scalar(msg_stmt) if message_result is None: raise ValueError("Message not found") db.session.close() diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 894d7906d..4385d0f08 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig @@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(ChatAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == 
app_config.app_id).first() + stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(stmt) if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 64dade296..8d2f3d488 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError +from sqlalchemy import select from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter @@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - message = ( - db.session.query(Message) - .where( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ("api" if isinstance(user, EndUser) else "console"), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ) - .first() + stmt = select(Message).where( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), ) + message = db.session.scalar(stmt) if not message: raise MessageNotExistsError() diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 50d2a0036..d384bff25 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from 
core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig @@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(CompletionAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(stmt) if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 1b107e072..92f3b6507 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -86,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = ( - db.session.query(AppModelConfig) - .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) - .first() + stmt = select(AppModelConfig).where( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id ) + app_model_config = db.session.scalar(stmt) if not app_model_config: raise AppModelConfigBrokenError() diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index b82934040..be183e208 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,6 +1,8 @@ import logging from typing import Optional +from sqlalchemy import select + from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ 
-25,9 +27,8 @@ class AnnotationReplyFeature: :param invoke_from: invoke from :return: """ - annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() - ) + stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id) + annotation_setting = db.session.scalar(stmt) if not annotation_setting: return None diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 956ef4e83..5c19eda21 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -86,7 +86,8 @@ class MessageCycleManager: def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() + stmt = select(Conversation).where(Conversation.id == conversation_id) + conversation = db.session.scalar(stmt) if not conversation: return diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index c55ba5e0f..5cf39d761 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from sqlalchemy import select + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler: for document in documents: if document.metadata is not None: document_id = document.metadata["document_id"] - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + 
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) + dataset_document = db.session.scalar(dataset_document_stmt) if not dataset_document: _logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s", @@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler: ) continue if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ( - db.session.query(ChildChunk) - .where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - .first() + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, ) + child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: segment = ( db.session.query(DocumentSegment) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index d81f372d4..2100e7fad 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,7 @@ from typing import Optional +from sqlalchemy import select + from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.external_data_tool.base import ExternalDataTool from core.helper import encrypter @@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - # get api_based_extension - api_based_extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) + 
api_based_extension = db.session.scalar(stmt) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool): raise ValueError(f"config is required, config: {self.config}") api_based_extension_id = self.config.get("api_based_extension_id") assert api_based_extension_id is not None, "api_based_extension_id is required" - # get api_based_extension - api_based_extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id ) + api_based_extension = db.session.scalar(stmt) if not api_based_extension: raise ValueError( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index a8e6c261c..4a768618f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -8,6 +8,7 @@ import uuid from typing import Any, Optional, cast from flask import current_app +from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -56,13 +57,11 @@ class IndexingRunner: if not dataset: raise ValueError("no dataset found") - # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() + stmt = select(DatasetProcessRule).where( + DatasetProcessRule.id == dataset_document.dataset_process_rule_id ) + processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") index_type = dataset_document.doc_form @@ -123,11 +122,8 @@ class IndexingRunner: db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - 
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() - ) + stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") @@ -208,7 +204,6 @@ class IndexingRunner: child_documents.append(child_document) document.children = child_documents documents.append(document) - # build index index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -310,7 +305,8 @@ class IndexingRunner: # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + stmt = select(UploadFile).where(UploadFile.id == upload_file_id) + image_file = db.session.scalar(stmt) if image_file is None: continue try: @@ -339,10 +335,8 @@ class IndexingRunner: if dataset_document.data_source_type == "upload_file": if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - - file_detail = ( - db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() - ) + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() if file_detail: extract_setting = ExtractSetting( diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 36f8c606b..17050fcad 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -110,9 +110,9 @@ class TokenBufferMemory: else: message_limit = 500 - stmt = stmt.limit(message_limit) + msg_limit_stmt = stmt.limit(message_limit) - messages = db.session.scalars(stmt).all() + messages = 
db.session.scalars(msg_limit_stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index af51b72cd..06d5c02bb 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,6 +1,7 @@ from typing import Optional from pydantic import BaseModel, Field +from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token @@ -87,10 +88,9 @@ class ApiModeration(Moderation): @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: - extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) + extension = db.session.scalar(stmt) return extension diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 1ddfc4cc2..c66105063 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -5,6 +5,7 @@ from typing import Optional from urllib.parse import urljoin from opentelemetry.trace import Link, Status, StatusCode +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( @@ -263,15 +264,15 @@ class AliyunDataTrace(BaseTraceInstance): app_id = trace_info.metadata.get("app_id") if not app_id: raise ValueError("No app_id found in trace_info metadata") - - app = session.query(App).where(App.id == app_id).first() + app_stmt = select(App).where(App.id == app_id) + app = session.scalar(app_stmt) if not app: raise 
ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - - service_account = session.query(Account).where(Account.id == app.created_by).first() + account_stmt = select(Account).where(Account.id == app.created_by) + service_account = session.scalar(account_stmt) if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") current_tenant = ( diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index f8e428daf..04b46d67a 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from sqlalchemy import select from sqlalchemy.orm import Session from core.ops.entities.config_entity import BaseTracingConfig @@ -44,14 +45,15 @@ class BaseTraceInstance(ABC): """ with Session(db.engine, expire_on_commit=False) as session: # Get the app to find its creator - app = session.query(App).where(App.id == app_id).first() + app_stmt = select(App).where(App.id == app_id) + app = session.scalar(app_stmt) if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - - service_account = session.query(Account).where(Account.id == app.created_by).first() + account_stmt = select(Account).where(Account.id == app.created_by) + service_account = session.scalar(account_stmt) if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index aa2f17553..fa9629947 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -226,9 +226,9 @@ class OpsTraceManager: if not trace_config_data: return None - # decrypt_token - app = db.session.query(App).where(App.id == 
app_id).first() + stmt = select(App).where(App.id == app_id) + app = db.session.scalar(stmt) if not app: raise ValueError("App not found") @@ -295,20 +295,19 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).where(Message.id == message_id).first() + message_stmt = select(Message).where(Message.id == message_id) + message_data = db.session.scalar(message_stmt) if not message_data: return None conversation_id = message_data.conversation_id - conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() + conversation_stmt = select(Conversation).where(Conversation.id == conversation_id) + conversation_data = db.session.scalar(conversation_stmt) if not conversation_data: return None if conversation_data.app_model_config_id: - app_model_config = ( - db.session.query(AppModelConfig) - .where(AppModelConfig.id == conversation_data.app_model_config_id) - .first() - ) + config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id) + app_model_config = db.session.scalar(config_stmt) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index cf62dc6ab..a79964644 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,8 @@ from collections.abc import Generator, Mapping from typing import Optional, Union +from sqlalchemy import select + from controllers.service_api.wraps import create_or_update_end_user_for_user_id from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -192,10 +194,11 @@ class 
PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ get the user by user id """ - - user = db.session.query(EndUser).where(EndUser.id == user_id).first() + stmt = select(EndUser).where(EndUser.id == user_id) + user = db.session.scalar(stmt) if not user: - user = db.session.query(Account).where(Account.id == user_id).first() + stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(stmt) if not user: raise ValueError("user not found") diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index cad0de647..04996442c 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -276,15 +276,11 @@ class ProviderManager: :param model_type: model type :return: """ - # Get the corresponding TenantDefaultModel record - default_model = ( - db.session.query(TenantDefaultModel) - .where( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(TenantDefaultModel).where( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) + default_model = db.session.scalar(stmt) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -367,16 +363,11 @@ class ProviderManager: model_names = [model.model for model in available_models] if model not in model_names: raise ValueError(f"Model {model} does not exist.") - - # Get the list of available models from get_configurations and check if it is LLM - default_model = ( - db.session.query(TenantDefaultModel) - .where( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(TenantDefaultModel).where( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) + default_model = db.session.scalar(stmt) # create or 
update TenantDefaultModel record if default_model: @@ -598,16 +589,13 @@ class ProviderManager: provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: db.session.rollback() - existed_provider_record = ( - db.session.query(Provider) - .where( - Provider.tenant_id == tenant_id, - Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value, - ) - .first() + stmt = select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == ModelProviderID(provider_name).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, ) + existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: continue diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index c98306ea4..5fb6f9fcc 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -3,6 +3,7 @@ from typing import Any, Optional import orjson from pydantic import BaseModel +from sqlalchemy import select from configs import dify_config from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -211,11 +212,10 @@ class Jieba(BaseKeyword): return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) - .first() + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id ) + document_segment = db.session.scalar(stmt) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) diff --git 
a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2912a48d9..fefd42f84 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app +from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config @@ -127,7 +128,8 @@ class RetrievalService: external_retrieval_model: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None, ): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = db.session.scalar(stmt) if not dataset: return [] metadata_condition = ( @@ -316,10 +318,8 @@ class RetrievalService: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: # Handle parent-child documents child_index_node_id = document.metadata.get("doc_id") - - child_chunk = ( - db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() - ) + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = db.session.scalar(child_chunk_stmt) if not child_chunk: continue @@ -378,17 +378,13 @@ class RetrievalService: index_node_id = document.metadata.get("doc_id") if not index_node_id: continue - - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - .first() + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, ) + segment = 
db.session.scalar(document_segment_stmt) if not segment: continue diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index fc3c2fc63..e55c06e66 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -18,6 +18,7 @@ from qdrant_client.http.models import ( TokenizerType, ) from qdrant_client.local.qdrant_local import QdrantLocal +from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -445,11 +446,8 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == dataset.collection_binding_id) - .one_or_none() - ) + stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id) + dataset_collection_binding = db.session.scalars(stmt).one_or_none() if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index feda11b7f..be24f5a56 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -20,6 +20,7 @@ from qdrant_client.http.models import ( ) from qdrant_client.local.qdrant_local import QdrantLocal from requests.auth import HTTPDigestAuth +from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def 
init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: - tidb_auth_binding = ( - db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() - ) + stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) + tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): - tidb_auth_binding = ( - db.session.query(TidbAuthBinding) - .where(TidbAuthBinding.tenant_id == dataset.tenant_id) - .one_or_none() - ) + stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) + tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index eef03ce41..661a8f37a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -3,6 +3,8 @@ import time from abc import ABC, abstractmethod from typing import Any, Optional +from sqlalchemy import select + from configs import dify_config from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -45,11 +47,10 @@ class Vector: vector_type = self._dataset.index_struct_dict["type"] else: if dify_config.VECTOR_STORE_WHITELIST_ENABLE: - whitelist = ( - db.session.query(Whitelist) - .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") - .one_or_none() + stmt = select(Whitelist).where( + Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db" ) + whitelist = db.session.scalars(stmt).one_or_none() if whitelist: vector_type = VectorType.TIDB_ON_QDRANT diff --git a/api/core/rag/docstore/dataset_docstore.py 
b/api/core/rag/docstore/dataset_docstore.py index f8da3657f..717cfe8f5 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any, Optional -from sqlalchemy import func +from sqlalchemy import func, select from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -41,9 +41,8 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: - document_segments = ( - db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() - ) + stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id) + document_segments = db.session.scalars(stmt).all() output = {} for document_segment in document_segments: @@ -228,10 +227,9 @@ class DatasetDocumentStore: return data def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: - document_segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) - .first() + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id ) + document_segment = db.session.scalar(stmt) return document_segment diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 3d4b898c9..206b2bb92 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,6 +4,7 @@ import operator from typing import Any, Optional, cast import requests +from sqlalchemy import select from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -367,18 +368,13 @@ class NotionExtractor(BaseExtractor): @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: - data_source_binding = ( - 
db.session.query(DataSourceOauthBinding) - .where( - db.and_( - DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', - ) - ) - .first() + stmt = select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) + data_source_binding = db.session.scalar(stmt) if not data_source_binding: raise Exception( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 49c72b4ba..11010c9d6 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping from typing import Any, Optional, Union, cast from flask import Flask, current_app -from sqlalchemy import Float, and_, or_, text +from sqlalchemy import Float, and_, or_, select, text from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy.orm import Session @@ -135,7 +135,8 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) # pass if dataset is not available if not dataset: @@ -240,15 +241,12 @@ class DatasetRetrieval: for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - db.session.query(DatasetDocument) - .where( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - 
DatasetDocument.archived == False, - ) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ) + document = db.session.scalar(dataset_document_stmt) if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, @@ -327,7 +325,8 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) if dataset: results = [] if dataset.provider == "external": @@ -514,22 +513,18 @@ class DatasetRetrieval: dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: if document.metadata is not None: - dataset_document = ( - db.session.query(DatasetDocument) - .where(DatasetDocument.id == document.metadata["document_id"]) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == document.metadata["document_id"] ) + dataset_document = db.session.scalar(dataset_document_stmt) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ( - db.session.query(ChildChunk) - .where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - .first() + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, ) + child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: segment = ( db.session.query(DocumentSegment) @@ -600,7 +595,8 @@ class DatasetRetrieval: ): with flask_app.app_context(): with Session(db.engine) as session: - dataset = 
session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = session.scalar(dataset_stmt) if not dataset: return [] @@ -685,7 +681,8 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) # pass if dataset is not available if not dataset: @@ -958,7 +955,8 @@ class DatasetRetrieval: self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> Optional[list[dict[str, Any]]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) + metadata_fields = db.session.scalars(metadata_stmt).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] # get metadata model config if metadata_model_config is None: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index cdfefbadb..90c09a444 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController @@ -54,17 +56,13 @@ class ToolLabelManager: return controller.tool_labels else: raise ValueError("Unsupported tool type") - - labels = ( - db.session.query(ToolLabelBinding.label_name) - .where( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == 
controller.provider_type.value, - ) - .all() + stmt = select(ToolLabelBinding.label_name).where( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, ) + labels = db.session.scalars(stmt).all() - return [label.label_name for label in labels] + return list(labels) @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b338a779a..eefeac934 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter +from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -198,14 +199,11 @@ class ToolManager: # get specific credentials if is_valid_uuid(credential_id): try: - builtin_provider = ( - db.session.query(BuiltinToolProvider) - .where( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - .first() + builtin_provider_stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, ) + builtin_provider = db.session.scalar(builtin_provider_stmt) except Exception as e: builtin_provider = None logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) @@ -317,11 +315,10 @@ class ToolManager: ), ) elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + workflow_provider_stmt = select(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id ) + workflow_provider = db.session.scalar(workflow_provider_stmt) if workflow_provider is None: raise 
ToolProviderNotFoundError(f"workflow provider {provider_id} not found") diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 56c6a9fbe..75c0c6738 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -3,6 +3,7 @@ from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field +from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager @@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ) - .all() + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), ) + segments = db.session.scalars(document_segment_stmt).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} @@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): resource_number = 1 for segment in sorted_segments: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - db.session.query(Document) - .where( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived 
== False, - ) - .first() + document_stmt = select(Document).where( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ) + document = db.session.scalar(document_stmt) if dataset and document: source = RetrievalSourceMetadata( position=resource_number, @@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): hit_callbacks: list[DatasetIndexToolCallbackHandler], ): with flask_app.app_context(): - dataset = ( - db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() - ) + stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(stmt) if not dataset: return [] diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f7689d770..b536c5a25 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,6 +1,7 @@ from typing import Any, Optional, cast from pydantic import BaseModel, Field +from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService @@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): ) def _run(self, query: str) -> str: - dataset = ( - db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() - ) + dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id) + dataset = db.session.scalar(dataset_stmt) if not dataset: return "" @@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - 
db.session.query(DatasetDocument) # type: ignore - .where( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ) + document = db.session.scalar(dataset_document_stmt) # type: ignore if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index ea219af68..9bcc63952 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,6 +3,8 @@ import logging from collections.abc import Generator from typing import Any, Optional +from sqlalchemy import select + from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -136,7 +138,8 @@ class WorkflowTool(Tool): .first() ) else: - workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() + stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) + workflow = db.session.scalar(stmt) if not workflow: raise ValueError("workflow not found or not published") @@ -147,7 +150,8 @@ class WorkflowTool(Tool): """ get the app by app id """ - app = db.session.query(App).where(App.id == app_id).first() + stmt = select(App).where(App.id == app_id) + app = db.session.scalar(stmt) if not app: raise ValueError("app not found") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index a44f15f87..f1f2fcdbc 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ 
b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast -from sqlalchemy import Float, and_, func, or_, text +from sqlalchemy import Float, and_, func, or_, select, text from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy.orm import sessionmaker @@ -367,15 +367,12 @@ class KnowledgeRetrievalNode(BaseNode): for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore - document = ( - db.session.query(Document) - .where( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ) - .first() + stmt = select(Document).where( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ) + document = db.session.scalar(stmt) if dataset and document: source = { "metadata": { @@ -514,7 +511,8 @@ class KnowledgeRetrievalNode(BaseNode): self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> list[dict[str, Any]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) + metadata_fields = db.session.scalars(stmt).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] if node_data.metadata_model_config is None: raise ValueError("metadata_model_config is required")