diff --git a/api/commands.py b/api/commands.py index 9f933a378..c2e62ec26 100644 --- a/api/commands.py +++ b/api/commands.py @@ -50,7 +50,7 @@ def reset_password(email, new_password, password_confirm): click.echo(click.style("Passwords do not match.", fg="red")) return - account = db.session.query(Account).filter(Account.email == email).one_or_none() + account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: click.echo(click.style("Account not found for email: {}".format(email), fg="red")) @@ -89,7 +89,7 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style("New emails do not match.", fg="red")) return - account = db.session.query(Account).filter(Account.email == email).one_or_none() + account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: click.echo(click.style("Account not found for email: {}".format(email), fg="red")) @@ -136,8 +136,8 @@ def reset_encrypt_key_pair(): tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() + db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() + db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() click.echo( @@ -172,7 +172,7 @@ def migrate_annotation_vector_database(): per_page = 50 apps = ( db.session.query(App) - .filter(App.status == "normal") + .where(App.status == "normal") .order_by(App.created_at.desc()) .limit(per_page) .offset((page - 1) * per_page) @@ -192,7 +192,7 @@ def migrate_annotation_vector_database(): try: click.echo("Creating app annotation index: {}".format(app.id)) app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() ) if not app_annotation_setting: @@ -202,13 +202,13 @@ def migrate_annotation_vector_database(): # get dataset_collection_binding info dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .first() ) if not dataset_collection_binding: click.echo("App annotation collection binding not found: {}".format(app.id)) continue - annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() + annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -305,7 +305,7 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -332,7 +332,7 @@ def migrate_knowledge_vector_database(): if dataset.collection_binding_id: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .where(DatasetCollectionBinding.id == dataset.collection_binding_id) .one_or_none() ) if dataset_collection_binding: @@ -367,7 +367,7 @@ def migrate_knowledge_vector_database(): dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset.id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -381,7 +381,7 @@ def migrate_knowledge_vector_database(): for dataset_document in dataset_documents: segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, @@ -468,7 +468,7 @@ def convert_to_agent_apps(): app_id = str(i.id) if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if app is not None: apps.append(app) @@ -483,7 +483,7 @@ def convert_to_agent_apps(): db.session.commit() # update conversation mode to agent - db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + db.session.query(Conversation).where(Conversation.app_id == app.id).update( {Conversation.mode: AppMode.AGENT_CHAT.value} ) @@ -560,7 +560,7 @@ def old_metadata_migration(): try: stmt = ( select(DatasetDocument) - .filter(DatasetDocument.doc_metadata.is_not(None)) + .where(DatasetDocument.doc_metadata.is_not(None)) .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -578,7 +578,7 @@ def old_metadata_migration(): else: dataset_metadata = ( db.session.query(DatasetMetadata) - .filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) + .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) .first() ) if not dataset_metadata: @@ -602,7 +602,7 @@ def old_metadata_migration(): else: dataset_metadata_binding = ( db.session.query(DatasetMetadataBinding) # type: ignore - .filter( + .where( DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id, @@ -717,7 +717,7 @@ where sites.id is null limit 1000""" continue try: - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: print(f"App {app_id} not found") continue diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index f5257fae7..8a55197fb 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") @@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource): with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) + select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() if not recommended_app: @@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) + select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() if not recommended_app: return {"result": "success"}, 204 with Session(db.engine) as session: - app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() + app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() if app: app.is_public = False with Session(db.engine) as session: installed_apps = session.execute( - select(InstalledApp).filter( + select(InstalledApp).where( InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 47c93a15c..d7500c415 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource): _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) keys = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .all() ) return {"items": keys} @@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource): current_key_count = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .count() ) @@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource): key = ( db.session.query(ApiToken) - .filter( + .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, @@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource): if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 4eef9fed4..b5b6d1f75 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -49,7 +49,7 @@ class CompletionConversationApi(Resource): query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") if args["keyword"]: - query = query.join(Message, Message.conversation_id == Conversation.id).filter( + query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( Message.query.ilike("%{}%".format(args["keyword"])), Message.answer.ilike("%{}%".format(args["keyword"])), @@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) @@ -181,7 +181,7 @@ class ChatConversationApi(Resource): Message.conversation_id == Conversation.id, ) .join(subquery, subquery.c.conversation_id == Conversation.id) - .filter( + .where( or_( Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter), @@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) @@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps//chat-conversati def _get_conversation(app_model, conversation_id): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .first() ) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 503393f26..2344fd5ac 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -26,7 +26,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_fields) def get(self, app_model): - server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first() + server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() return server @setup_required @@ -73,7 +73,7 @@ class AppMCPServerController(Resource): parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("status", type=str, required=False, location="json") args = parser.parse_args() - server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() + server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() if not server: raise NotFound() @@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource): raise NotFound() server = ( db.session.query(AppMCPServer) - .filter(AppMCPServer.id == server_id) - .filter(AppMCPServer.tenant_id == current_user.current_tenant_id) + .where(AppMCPServer.id == server_id) + .where(AppMCPServer.tenant_id == current_user.current_tenant_id) .first() ) if not server: diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index ea659f9f5..5e79e8dec 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -56,7 +56,7 @@ class ChatMessageListApi(Resource): conversation = ( db.session.query(Conversation) - .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .first() ) @@ -66,7 +66,7 @@ class ChatMessageListApi(Resource): if args["first_id"]: first_message = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .first() ) @@ -75,7 +75,7 @@ class ChatMessageListApi(Resource): history_messages = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, Message.id != first_message.id, @@ -87,7 +87,7 @@ class ChatMessageListApi(Resource): else: history_messages = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id) + .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args["limit"]) .all() @@ -98,7 +98,7 @@ class ChatMessageListApi(Resource): current_page_first_message = history_messages[-1] rest_count = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id, @@ -167,7 +167,7 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() + count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() return {"count": count} @@ -214,7 +214,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f30e3e893..029138fb6 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -42,7 +42,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config original_app_model_config = ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() ) if original_app_model_config is None: raise ValueError("Original app model config not found") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 358a5e8cd..03418f1dd 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -49,7 +49,7 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise NotFound @@ -93,7 +93,7 @@ class AppSiteAccessTokenReset(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise NotFound diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 3322350e2..132dc1f96 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -11,7 +11,7 @@ from models import App, AppMode def _load_app_model(app_id: str) -> Optional[App]: app_model = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) return app_model diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index b49f8affc..39f8ab578 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -30,7 +30,7 @@ class DataSourceApi(Resource): # get workspace data source integrates data_source_integrates = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.disabled == False, ) @@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource): page_id = str(page_id) with Session(db.engine) as session: data_source_binding = session.execute( - select(DataSourceOauthBinding).filter( + select(DataSourceOauthBinding).where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 4f62ac78b..f551bc243 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -412,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource): file_ids = args["info_list"]["file_info_list"]["file_ids"] file_details = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .all() ) @@ -517,14 +517,14 @@ class DatasetIndexingStatusApi(Resource): dataset_id = str(dataset_id) documents = ( db.session.query(Document) - .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) .all() ) documents_status = [] for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -533,7 +533,7 @@ class DatasetIndexingStatusApi(Resource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields @@ -568,7 +568,7 @@ class DatasetApiKeyApi(Resource): def get(self): keys = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .all() ) return {"items": keys} @@ -584,7 +584,7 @@ class DatasetApiKeyApi(Resource): current_key_count = ( db.session.query(ApiToken) - .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .count() ) @@ -620,7 +620,7 @@ class DatasetApiDeleteApi(Resource): key = ( db.session.query(ApiToken) - .filter( + .where( ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, @@ -631,7 +631,7 @@ class DatasetApiDeleteApi(Resource): if key is None: flask_restful.abort(404, message="API key not found") - db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 28a2e9304..d14b208a4 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource): # get the latest process rule dataset_process_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) .one_or_none() @@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource): if search: search = f"%{search}%" - query = query.filter(Document.name.like(search)) + query = query.where(Document.name.like(search)) if sort.startswith("-"): sort_logic = desc @@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) document.completed_segments = completed_segments @@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource): file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .first() ) @@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): file_id = data_source_info["upload_file_id"] file_detail = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .first() ) @@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields @@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource): completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment", @@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .count() ) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 48142dbe7..b3704ce8b 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource): query = ( select(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id, ) @@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource): ) if status_list: - query = query.filter(DocumentSegment.status.in_(status_list)) + query = query.where(DocumentSegment.status.in_(status_list)) if hit_count_gte is not None: - query = query.filter(DocumentSegment.hit_count >= hit_count_gte) + query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) if args["enabled"].lower() != "all": if args["enabled"].lower() == "true": - query = query.filter(DocumentSegment.enabled == True) + query = query.where(DocumentSegment.enabled == True) elif args["enabled"].lower() == "false": - query = query.filter(DocumentSegment.enabled == False) + query = query.where(DocumentSegment.enabled == False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .first() ) if not child_chunk: @@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .first() ) if not segment: @@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .first() ) if not child_chunk: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 29111fb86..ffdf73c36 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource): if app_id: installed_apps = ( db.session.query(InstalledApp) - .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) + .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) .all() ) else: - installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() + installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ @@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource): parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == args["app_id"]).first() if app is None: raise NotFound("App not found") @@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource): installed_app = ( db.session.query(InstalledApp) - .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .first() ) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index afbd78bd5..de97fb149 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -28,7 +28,7 @@ def installed_app_required(view=None): installed_app = ( db.session.query(InstalledApp) - .filter( + .where( InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id ) .first() diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 072e904ca..ef814dd73 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -21,7 +21,7 @@ def plugin_permission_required( with Session(db.engine) as session: permission = ( session.query(TenantPluginPermission) - .filter( + .where( TenantPluginPermission.tenant_id == tenant_id, ) .first() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 7f7e64a59..5cd2e0cd2 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -68,7 +68,7 @@ class AccountInitApi(Resource): # check invitation code invitation_code = ( db.session.query(InvitationCode) - .filter( + .where( InvitationCode.code == args["invitation_code"], InvitationCode.status == "unused", ) @@ -228,7 +228,7 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index b1f79ffde..f7424923b 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): - member = db.session.query(Account).filter(Account.id == str(member_id)).first() + member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) else: diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 50408e092..b533614d4 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: user_id = "DEFAULT-USER" if user_id == "DEFAULT-USER": - user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() + user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first() if not user_model: user_model = EndUser( tenant_id=tenant_id, @@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: else: user_model = AccountService.load_user(user_id) if not user_model: - user_model = session.query(EndUser).filter(EndUser.id == user_id).first() + user_model = session.query(EndUser).where(EndUser.id == user_id).first() if not user_model: raise ValueError("user not found") except Exception: @@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None): try: tenant_model = ( db.session.query(Tenant) - .filter( + .where( Tenant.id == tenant_id, ) .first() diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index f3a9312dd..9e7b3d4f2 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view): if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index ead728bfb..87d678796 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -30,7 +30,7 @@ class MCPAppApi(Resource): request_id = args.get("id") - server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() + server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() if not server: return helper.compact_generate_response( create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") @@ -41,7 +41,7 @@ class MCPAppApi(Resource): create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") ) - app = db.session.query(App).filter(App.id == server.app_id).first() + app = db.session.query(App).where(App.id == server.app_id).first() if not app: return helper.compact_generate_response( create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index e752dfee3..c157b39f6 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -16,7 +16,7 @@ class AppSiteApi(Resource): @marshal_with(fields.site_fields) def get(self, app_model: App): """Retrieve app site info.""" - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d571b21a0..ac85c0b38 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource): # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset does not exist.") @@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) search = request.args.get("keyword", default=None, type=str) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource): if search: search = f"%{search}%" - query = query.filter(Document.name.like(search)) + query = query.where(Document.name.like(search)) query = query.order_by(desc(Document.created_at), desc(Document.position)) @@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # get documents @@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): for document in documents: completed_segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment", @@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): ) total_segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 403b7f0a0..31f862dc8 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource): tenant_id = str(tenant_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 6382b63ea..3b4721b5b 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document @@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource): data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("UploadFile not found.") else: diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index eeed32143..da81cc8bc 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -44,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorated_view(*args, **kwargs): api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).filter(App.id == api_token.app_id).first() + app_model = db.session.query(App).where(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") @@ -54,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -62,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == api_token.tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role.in_(["owner"])) - .filter(Tenant.status == TenantStatus.NORMAL) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) .one_or_none() ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter(Account.id == ta.account_id).first() + account = db.session.query(Account).where(Account.id == ta.account_id).first() # Login admin if account: account.current_tenant = tenant @@ -213,15 +213,15 @@ def validate_dataset_token(view=None): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == api_token.tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role.in_(["owner"])) - .filter(Tenant.status == TenantStatus.NORMAL) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) .one_or_none() ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter(Account.id == ta.account_id).first() + account = db.session.query(Account).where(Account.id == ta.account_id).first() # Login admin if account: account.current_tenant = tenant @@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] end_user = ( db.session.query(EndUser) - .filter( + .where( EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, EndUser.session_id == user_id, @@ -320,7 +320,7 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 10c3cdcf0..acd3a8b53 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta from flask import request from flask_restful import Resource +from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -42,17 +43,17 @@ class PassportResource(Resource): raise WebAppAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() # get app from db and check if it is normal and enable_site - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalar(select(App).where(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) ) if end_user: @@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not user_auth_type: raise Unauthorized("Missing auth_type in the token.") - site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) if not site: raise NotFound() - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.scalar(select(App).where(App.id == site.app_id)) if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() @@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: end_user = None if end_user_id: - end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if session_id: - end_user = ( - db.session.query(EndUser) - .filter( + end_user = db.session.scalar( + select(EndUser).where( EndUser.session_id == session_id, EndUser.tenant_id == app_model.tenant_id, EndUser.app_id == app_model.id, ) - .first() ) if not end_user: if not session_id: @@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): user_id = token_decoded.get("user_id") end_user = None if user_id: - end_user = ( - db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) ) if not end_user: @@ -224,6 +223,8 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() + existing_count = db.session.scalar( + select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id) + ) if existing_count == 0: return session_id diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea..3c133499b 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource): def get(self, app_model, end_user): """Retrieve app site info.""" # get site - site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 154bddfc5..ae6f14a68 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -3,6 +3,7 @@ from functools import wraps from flask import request from flask_restful import Resource +from sqlalchemy import select from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError @@ -48,8 +49,8 @@ def decode_jwt_token(): decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") - app_model = db.session.query(App).filter(App.id == app_id).first() - site = db.session.query(Site).filter(Site.code == app_code).first() + app_model = db.session.scalar(select(App).where(App.id == app_id)) + site = db.session.scalar(select(Site).where(Site.code == app_code)) if not app_model: raise NotFound() if not app_code or not site: @@ -57,7 +58,7 @@ def decode_jwt_token(): if app_model.enable_site is False: raise BadRequest("Site is disabled.") end_user_id = decoded.get("end_user_id") - end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if not end_user: raise NotFound() diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 28bf4a9a2..1f3c218d5 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner): # get how many agent thoughts have been created self.agent_thought_count = ( db.session.query(MessageAgentThought) - .filter( + .where( MessageAgentThought.message_id == self.message.id, ) .count() @@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner): Save agent thought """ updated_agent_thought = ( - db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() ) if not updated_agent_thought: raise ValueError("agent thought not found") @@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner): return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() if not files: return UserPromptMessage(content=message.query) if message.app_model_config: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 80af9a3c6..a75e17af6 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -72,7 +72,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 71328f6d1..39d6ba39f 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(AgentChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") @@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() if conversation_result is None: raise ValueError("Conversation not found") - message_result = db.session.query(Message).filter(Message.id == message.id).first() + message_result = db.session.query(Message).where(Message.id == message.id).first() if message_result is None: raise ValueError("Message not found") db.session.close() diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 39597fc03..894d7906d 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(ChatAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 195e7e2e3..9356bd1ce 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): """ message = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.app_id == app_model.id, Message.from_source == ("api" if isinstance(user, EndUser) else "console"), diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 80fdd0b80..50d2a0036 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner): app_config = application_generate_entity.app_config app_config = cast(CompletionAppConfig, app_config) - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() + app_record = db.session.query(App).where(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index d50cf1c94..f5bc480f0 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): if conversation: app_model_config = ( db.session.query(AppModelConfig) - .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .first() ) @@ -259,7 +259,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError("Conversation not exists") @@ -272,7 +272,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = db.session.query(Message).filter(Message.id == message_id).first() + message = db.session.query(Message).where(Message.id == message_id).first() if message is None: raise MessageNotExistsError("Message not exists") diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 83fd3deba..54dc69302 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -26,7 +26,7 @@ class AnnotationReplyFeature: :return: """ annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() ) if not annotation_setting: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 3c8c7bb5a..888434798 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :return: """ agent_thought: Optional[MessageAgentThought] = ( - db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() + db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() ) if agent_thought: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2343081ea..824da0b93 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -81,7 +81,7 @@ class MessageCycleManager: def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation: return @@ -140,7 +140,7 @@ class MessageCycleManager: :param event: event :return: """ - message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() + message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() if message_file and message_file.url is not None: # get tool file id diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index a3a7b4b81..c55ba5e0f 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler: for document in documents: if document.metadata is not None: document_id = document.metadata["document_id"] - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: _logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s", @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.document_id == dataset_document.id, @@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler: if child_chunk: segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == child_chunk.segment_id) + .where(DocumentSegment.id == child_chunk.segment_id) .update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False ) ) else: - query = db.session.query(DocumentSegment).filter( + query = db.session.query(DocumentSegment).where( DocumentSegment.index_node_id == document.metadata["doc_id"] ) if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 66d8d0f41..af5c18e26 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel): provider_record = ( db.session.query(Provider) - .filter( + .where( Provider.tenant_id == self.tenant_id, Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_name.in_(provider_names), @@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel): provider_model_record = ( db.session.query(ProviderModel) - .filter( + .where( ProviderModel.tenant_id == self.tenant_id, ProviderModel.provider_name.in_(provider_names), ProviderModel.model_name == model, @@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel): return ( db.session.query(ProviderModelSetting) - .filter( + .where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), @@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel): return ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), @@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel): load_balancing_config_count = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), @@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel): model_setting = ( db.session.query(ProviderModelSetting) - .filter( + .where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), @@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = ( db.session.query(TenantPreferredModelProvider) - .filter( + .where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 53acdf075..2099a9e34 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool): # get api_based_extension api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) @@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool): # get api_based_extension api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 1e40997a8..f761d2037 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str): from models.account import Tenant from models.engine import db - if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): + if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index e5976f4c9..fc5d0547f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -59,7 +59,7 @@ class IndexingRunner: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -119,12 +119,12 @@ class IndexingRunner: db.session.delete(document_segment) if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: # delete child chunks - db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() + db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -212,7 +212,7 @@ class IndexingRunner: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) @@ -316,7 +316,7 @@ class IndexingRunner: # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if image_file is None: continue try: @@ -346,7 +346,7 @@ class IndexingRunner: raise ValueError("no upload file found") file_detail = ( - db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() ) if file_detail: @@ -599,7 +599,7 @@ class IndexingRunner: keyword.create(documents) if dataset.indexing_technique != "high_quality": document_ids = [document.metadata["doc_id"] for document in documents] - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), @@ -630,7 +630,7 @@ class IndexingRunner: index_processor.load(dataset, chunk_documents, with_keywords=False) document_ids = [document.metadata["doc_id"] for document in chunk_documents] - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 20ff7e752..496b5432a 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler: ): self.app = app self.request = request - mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first() + mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() if not mcp_server: raise ValueError("MCP server not found") self.mcp_server: AppMCPServer = mcp_server @@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler: def retrieve_end_user(self): return ( db.session.query(EndUser) - .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") + .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .first() ) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index a9f0a92e5..7ce124594 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -67,7 +67,7 @@ class TokenBufferMemory: prompt_messages: list[PromptMessage] = [] for message in messages: - files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() if files: file_extra_config = None if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index c65a3885f..332381555 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -89,7 +89,7 @@ class ApiModeration(Moderation): def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index bbbc12a2c..cf367efdf 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -120,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: user_id = end_user_data.session_id @@ -244,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance): if not app_id: raise ValueError("No app_id found in trace_info metadata") - app = session.query(App).filter(App.id == app_id).first() + app = session.query(App).where(App.id == app_id).first() if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - service_account = session.query(Account).filter(Account.id == app.created_by).first() + service_account = session.query(Account).where(Account.id == app.created_by).first() if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") current_tenant = ( diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 14dba4423..1b72a4775 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -297,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: message_metadata["end_user_id"] = end_user_data.session_id @@ -703,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.execution_metadata, ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .all() ) return workflow_nodes diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 8593198bc..f8e428daf 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -44,14 +44,14 @@ class BaseTraceInstance(ABC): """ with Session(db.engine, expire_on_commit=False) as session: # Get the app to find its creator - app = session.query(App).filter(App.id == app_id).first() + app = session.query(App).where(App.id == app_id).first() if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - service_account = session.query(Account).filter(Account.id == app.created_by).first() + service_account = session.query(Account).where(Account.id == app.created_by).first() if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6dadb2897..f4a59ef3a 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -244,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: user_id = end_user_data.session_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 324678227..c97846dc9 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -262,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index dfa7052c3..6079b2fae 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -284,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 34963efab..2b546b47c 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -218,7 +218,7 @@ class OpsTraceManager: """ trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -226,7 +226,7 @@ class OpsTraceManager: return None # decrypt_token - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") @@ -253,7 +253,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -293,18 +293,18 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).filter(Message.id == message_id).first() + message_data = db.session.query(Message).where(Message.id == message_id).first() if not message_data: return None conversation_id = message_data.conversation_id - conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() if not conversation_data: return None if conversation_data.app_model_config_id: app_model_config = ( db.session.query(AppModelConfig) - .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .where(AppModelConfig.id == conversation_data.app_model_config_id) .first() ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: @@ -331,7 +331,7 @@ class OpsTraceManager: if tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -349,7 +349,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 36d060afd..573e8cac8 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -3,6 +3,8 @@ from datetime import datetime from typing import Optional, Union from urllib.parse import urlparse +from sqlalchemy import select + from extensions.ext_database import db from models.model import Message @@ -20,7 +22,7 @@ def filter_none_values(data: dict): def get_message_data(message_id: str): - return db.session.query(Message).filter(Message.id == message_id).first() + return db.session.scalar(select(Message).where(Message.id == message_id)) @contextmanager diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 4bd41ce4a..a34b3b780 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance): if message_data.from_end_user_id: end_user_data: Optional[EndUser] = ( - db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: end_user_id = end_user_data.session_id diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 4e43561a1..e8c9bed09 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get the user by user id """ - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() + user = db.session.query(EndUser).where(EndUser.id == user_id).first() if not user: - user = db.session.query(Account).filter(Account.id == user_id).first() + user = db.session.query(Account).where(Account.id == user_id).first() if not user: raise ValueError("user not found") @@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() + app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() except Exception: raise ValueError("app not found") diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 488a39467..6de4f3a30 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -275,7 +275,7 @@ class ProviderManager: # Get the corresponding TenantDefaultModel record default_model = ( db.session.query(TenantDefaultModel) - .filter( + .where( TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) @@ -367,7 +367,7 @@ class ProviderManager: # Get the list of available models from get_configurations and check if it is LLM default_model = ( db.session.query(TenantDefaultModel) - .filter( + .where( TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) @@ -541,7 +541,7 @@ class ProviderManager: db.session.rollback() existed_provider_record = ( db.session.query(Provider) - .filter( + .where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index d6d0bd88b..ec3a23bd9 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -93,11 +93,11 @@ class Jieba(BaseKeyword): documents = [] for chunk_index in sorted_chunk_indices: - segment_query = db.session.query(DocumentSegment).filter( + segment_query = db.session.query(DocumentSegment).where( DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index ) if document_ids_filter: - segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) + segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter)) segment = segment_query.first() if segment: @@ -214,7 +214,7 @@ class Jieba(BaseKeyword): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): document_segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) .first() ) if document_segment: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 5a6903d3d..e872a4e37 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -127,7 +127,7 @@ class RetrievalService: external_retrieval_model: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None, ): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return [] metadata_condition = ( @@ -145,7 +145,7 @@ class RetrievalService: @classmethod def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: with Session(db.engine) as session: - return session.query(Dataset).filter(Dataset.id == dataset_id).first() + return session.query(Dataset).where(Dataset.id == dataset_id).first() @classmethod def keyword_search( @@ -294,7 +294,7 @@ class RetrievalService: dataset_documents = { doc.id: doc for doc in db.session.query(DatasetDocument) - .filter(DatasetDocument.id.in_(document_ids)) + .where(DatasetDocument.id.in_(document_ids)) .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) .all() } @@ -318,7 +318,7 @@ class RetrievalService: child_index_node_id = document.metadata.get("doc_id") child_chunk = ( - db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() + db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() ) if not child_chunk: @@ -326,7 +326,7 @@ class RetrievalService: segment = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.enabled == True, DocumentSegment.status == "completed", @@ -381,7 +381,7 @@ class RetrievalService: segment = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.enabled == True, DocumentSegment.status == "completed", diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 05fa73011..dfb95a183 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory): if dataset.collection_binding_id: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .where(DatasetCollectionBinding.id == dataset.collection_binding_id) .one_or_none() ) if dataset_collection_binding: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 6f895b12a..ba6a9654f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: tidb_auth_binding = ( - db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() + db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() ) if not tidb_auth_binding: with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): tidb_auth_binding = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .where(TidbAuthBinding.tenant_id == dataset.tenant_id) .one_or_none() ) if tidb_auth_binding: @@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): else: idle_tidb_auth_binding = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .limit(1) .one_or_none() ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 00080b0fa..e018f7d3d 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -47,7 +47,7 @@ class Vector: if dify_config.VECTOR_STORE_WHITELIST_ENABLE: whitelist = ( db.session.query(Whitelist) - .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") + .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") .one_or_none() ) if whitelist: diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 398b0daad..f844770a2 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -42,7 +42,7 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: document_segments = ( - db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() ) output = {} @@ -63,7 +63,7 @@ class DatasetDocumentStore: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == self._document_id) + .where(DocumentSegment.document_id == self._document_id) .scalar() ) @@ -147,7 +147,7 @@ class DatasetDocumentStore: segment_document.tokens = tokens if save_child and doc.children: # delete the existing child chunks - db.session.query(ChildChunk).filter( + db.session.query(ChildChunk).where( ChildChunk.tenant_id == self._dataset.tenant_id, ChildChunk.dataset_id == self._dataset.id, ChildChunk.document_id == self._document_id, @@ -230,7 +230,7 @@ class DatasetDocumentStore: def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) .first() ) diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 81a0810e2..875626eb3 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor): def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, DataSourceOauthBinding.provider == "notion", diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1cde5e1c8..52756fbac 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = ( db.session.query(ChildChunk.index_node_id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .filter( + .where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ChildChunk.dataset_id == dataset.id, @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] vector.delete_by_ids(child_node_ids) if delete_child_chunks: - db.session.query(ChildChunk).filter( + db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ).delete() db.session.commit() @@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): vector.delete() if delete_child_chunks: - db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete() db.session.commit() def retrieve( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3d0f0f97b..a25bc6564 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -135,7 +135,7 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: @@ -242,7 +242,7 @@ class DatasetRetrieval: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, DatasetDocument.archived == False, @@ -327,7 +327,7 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset: results = [] if dataset.provider == "external": @@ -516,14 +516,14 @@ class DatasetRetrieval: if document.metadata is not None: dataset_document = ( db.session.query(DatasetDocument) - .filter(DatasetDocument.id == document.metadata["document_id"]) + .where(DatasetDocument.id == document.metadata["document_id"]) .first() ) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: child_chunk = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.document_id == dataset_document.id, @@ -533,7 +533,7 @@ class DatasetRetrieval: if child_chunk: segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == child_chunk.segment_id) + .where(DocumentSegment.id == child_chunk.segment_id) .update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False, @@ -541,13 +541,13 @@ class DatasetRetrieval: ) db.session.commit() else: - query = db.session.query(DocumentSegment).filter( + query = db.session.query(DocumentSegment).where( DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment query.update( @@ -600,7 +600,7 @@ class DatasetRetrieval: ): with flask_app.app_context(): with Session(db.engine) as session: - dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return [] @@ -685,7 +685,7 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: @@ -862,7 +862,7 @@ class DatasetRetrieval: metadata_filtering_conditions: Optional[MetadataFilteringCondition], inputs: dict, ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: - document_query = db.session.query(DatasetDocument).filter( + document_query = db.session.query(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -930,9 +930,9 @@ class DatasetRetrieval: raise ValueError("Invalid metadata filtering mode") if filters: if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore - document_query = document_query.filter(and_(*filters)) + document_query = document_query.where(and_(*filters)) else: - document_query = document_query.filter(or_(*filters)) + document_query = document_query.where(or_(*filters)) documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore @@ -958,7 +958,7 @@ class DatasetRetrieval: self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> Optional[list[dict[str, Any]]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] # get metadata model config if metadata_model_config is None: diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index fbe1d7913..95fab6151 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController): # get tenant api providers db_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) + .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .all() ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ece02f9d5..ff054041c 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -160,7 +160,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == id, ) .first() @@ -184,7 +184,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: message_file: MessageFile | None = ( session.query(MessageFile) - .filter( + .where( MessageFile.id == id, ) .first() @@ -204,7 +204,7 @@ class ToolFileManager: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == tool_file_id, ) .first() @@ -228,7 +228,7 @@ class ToolFileManager: with Session(self._engine, expire_on_commit=False) as session: tool_file: ToolFile | None = ( session.query(ToolFile) - .filter( + .where( ToolFile.id == tool_file_id, ) .first() diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 4787d7d79..cdfefbadb 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -29,7 +29,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() + db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: @@ -57,7 +57,7 @@ class ToolLabelManager: labels = ( db.session.query(ToolLabelBinding.label_name) - .filter( + .where( ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type.value, ) @@ -90,7 +90,7 @@ class ToolLabelManager: provider_ids.append(controller.provider_id) labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index abbdf8de3..f286466de 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -198,7 +198,7 @@ class ToolManager: try: builtin_provider = ( db.session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) @@ -216,7 +216,7 @@ class ToolManager: # use the default provider builtin_provider = ( db.session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == str(provider_id_entity)) | (BuiltinToolProvider.provider == provider_id_entity.provider_name), @@ -229,7 +229,7 @@ class ToolManager: else: builtin_provider = ( db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) @@ -316,7 +316,7 @@ class ToolManager: elif provider_type == ToolProviderType.WORKFLOW: workflow_provider = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) @@ -616,7 +616,7 @@ class ToolManager: ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] - return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod def list_providers_from_api( @@ -664,7 +664,7 @@ class ToolManager: # get db api providers if "api" in filters: db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() ) api_provider_controllers: list[dict[str, Any]] = [ @@ -687,7 +687,7 @@ class ToolManager: if "workflow" in filters: # get workflow providers workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() ) workflow_provider_controllers: list[WorkflowToolProviderController] = [] @@ -731,7 +731,7 @@ class ToolManager: """ provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) @@ -768,7 +768,7 @@ class ToolManager: """ provider: MCPToolProvider | None = ( db.session.query(MCPToolProvider) - .filter( + .where( MCPToolProvider.server_identifier == provider_id, MCPToolProvider.tenant_id == tenant_id, ) @@ -793,7 +793,7 @@ class ToolManager: provider_name = provider provider_obj: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) @@ -885,7 +885,7 @@ class ToolManager: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() ) @@ -902,7 +902,7 @@ class ToolManager: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .first() ) @@ -919,7 +919,7 @@ class ToolManager: try: mcp_provider: MCPToolProvider | None = ( db.session.query(MCPToolProvider) - .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) .first() ) diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 2cbc4b982..7eb4bc017 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), DocumentSegment.status == "completed", @@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(Document) - .filter( + .where( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, @@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): ): with flask_app.app_context(): dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() ) if not dataset: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index ff1d9021c..f7689d770 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): def _run(self, query: str) -> str: dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() ) if not dataset: @@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = ( db.session.query(DatasetDocument) # type: ignore - .filter( + .where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, DatasetDocument.archived == False, diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 7661e1e6a..83f5f558d 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController): """ workflow: Workflow | None = ( db.session.query(Workflow) - .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) + .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .first() ) @@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController): db_providers: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, ) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 10bf8ca64..8b89c2a7a 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -142,12 +142,12 @@ class WorkflowTool(Tool): if not version: workflow = ( db.session.query(Workflow) - .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .where(Workflow.app_id == app_id, Workflow.version != "draft") .order_by(Workflow.created_at.desc()) .first() ) else: - workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() + workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: raise ValueError("workflow not found or not published") @@ -158,7 +158,7 @@ class WorkflowTool(Tool): """ get the app by app id """ - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("app not found") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index be8fa4d22..34b0afc75 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode): # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode): results = ( db.session.query(Dataset) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) - .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) - .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) + .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .all() ) @@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode): dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore document = ( db.session.query(Document) - .filter( + .where( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, @@ -415,7 +415,7 @@ class KnowledgeRetrievalNode(BaseNode): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: - document_query = db.session.query(Document).filter( + document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", Document.enabled == True, @@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode): node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "and" ): # type: ignore - document_query = document_query.filter(and_(*filters)) + document_query = document_query.where(and_(*filters)) else: - document_query = document_query.filter(or_(*filters)) + document_query = document_query.where(or_(*filters)) documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore @@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(BaseNode): self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> list[dict[str, Any]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] if node_data.metadata_model_config is None: raise ValueError("metadata_model_config is required") diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index cb48bd92a..dc50ca8d9 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -22,7 +22,7 @@ def handle(sender, **kwargs): document = ( db.session.query(Document) - .filter( + .where( Document.id == document_id, Document.dataset_id == dataset_id, ) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 14396e992..b8b5a89dc 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -13,7 +13,7 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: @@ -27,7 +27,7 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).filter( + db.session.query(AppDatasetJoin).where( AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index dd2efed94..cf4ba6983 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -15,7 +15,7 @@ def handle(sender, **kwargs): published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: @@ -29,7 +29,7 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).filter( + db.session.query(AppDatasetJoin).where( AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 11d1856ac..9b18e25ea 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login): if workspace_id: tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == workspace_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role == "owner") + .where(Tenant.id == workspace_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role == "owner") .one_or_none() ) if tenant_account_join: @@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login): end_user_id = decoded.get("end_user_id") if not end_user_id: raise Unauthorized("Invalid Authorization token.") - end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound("End user not found.") return end_user @@ -78,12 +78,12 @@ def load_user_from_request(request_from_flask_login): server_code = request.view_args.get("server_code") if request.view_args else None if not server_code: raise Unauthorized("Invalid Authorization token.") - app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() + app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() if not app_mcp_server: raise NotFound("App MCP server not found.") end_user = ( db.session.query(EndUser) - .filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") + .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") .first() ) if not end_user: diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index c974dbb70..512a9cb60 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -261,13 +261,11 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - tool_file = ( - db.session.query(ToolFile) - .filter( + tool_file = db.session.scalar( + select(ToolFile).where( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, ) - .first() ) if tool_file is None: @@ -275,7 +273,7 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) + detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) specified_type = mapping.get("type") diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 78f827584..987c5d713 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -3,6 +3,7 @@ from typing import Any import requests from flask_login import current_user +from sqlalchemy import select from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -61,16 +62,12 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, - ) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) - .first() ) if data_source_binding: data_source_binding.source_info = source_info @@ -101,16 +98,12 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, - ) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) - .first() ) if data_source_binding: data_source_binding.source_info = source_info @@ -129,18 +122,15 @@ class NotionOAuth(OAuthDataSource): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, - ) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, ) - .first() ) + if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) diff --git a/api/models/account.py b/api/models/account.py index 01d1625db..d63c5d7fb 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -119,7 +119,7 @@ class Account(UserMixin, Base): @current_tenant.setter def current_tenant(self, tenant: "Tenant"): - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() + ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)) if ta: self.role = TenantAccountRole(ta.role) self._current_tenant = tenant @@ -135,9 +135,9 @@ class Account(UserMixin, Base): tuple[Tenant, TenantAccountJoin], ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == tenant_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.account_id == self.id) + .where(Tenant.id == tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.account_id == self.id) .one_or_none() ), ) @@ -161,11 +161,11 @@ class Account(UserMixin, Base): def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) - .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) .one_or_none() ) if account_integrate: - return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() + return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none() return None # check current_user.current_tenant.current_role in ['admin', 'owner'] @@ -211,7 +211,7 @@ class Tenant(Base): def get_accounts(self) -> list[Account]: return ( db.session.query(Account) - .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) .all() ) diff --git a/api/models/dataset.py b/api/models/dataset.py index d5a13efb9..d87754021 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,7 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -68,7 +68,7 @@ class Dataset(Base): @property def dataset_keyword_table(self): dataset_keyword_table = ( - db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first() ) if dataset_keyword_table: return dataset_keyword_table @@ -95,7 +95,7 @@ class Dataset(Base): def latest_process_rule(self): return ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == self.id) + .where(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) .first() ) @@ -104,19 +104,19 @@ class Dataset(Base): def app_count(self): return ( db.session.query(func.count(AppDatasetJoin.id)) - .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) .scalar() ) @property def document_count(self): - return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() @property def available_document_count(self): return ( db.session.query(func.count(Document.id)) - .filter( + .where( Document.dataset_id == self.id, Document.indexing_status == "completed", Document.enabled == True, @@ -129,7 +129,7 @@ class Dataset(Base): def available_segment_count(self): return ( db.session.query(func.count(DocumentSegment.id)) - .filter( + .where( DocumentSegment.dataset_id == self.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, @@ -142,13 +142,13 @@ class Dataset(Base): return ( db.session.query(Document) .with_entities(func.coalesce(func.sum(Document.word_count), 0)) - .filter(Document.dataset_id == self.id) + .where(Document.dataset_id == self.id) .scalar() ) @property def doc_form(self): - document = db.session.query(Document).filter(Document.dataset_id == self.id).first() + document = db.session.query(Document).where(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -169,7 +169,7 @@ class Dataset(Base): tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == self.id, TagBinding.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id, @@ -185,14 +185,14 @@ class Dataset(Base): if self.provider != "external": return None external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first() + db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first() ) if not external_knowledge_binding: return None - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis) - .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) - .first() + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis).where( + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id + ) ) if not external_knowledge_api: return None @@ -205,7 +205,7 @@ class Dataset(Base): @property def doc_metadata(self): - dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() + dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() doc_metadata = [ { @@ -408,7 +408,7 @@ class Document(Base): data_source_info_dict = json.loads(self.data_source_info) file_detail = ( db.session.query(UploadFile) - .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .where(UploadFile.id == data_source_info_dict["upload_file_id"]) .one_or_none() ) if file_detail: @@ -441,24 +441,24 @@ class Document(Base): @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() + return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none() @property def segment_count(self): - return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() + return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count() @property def hit_count(self): return ( db.session.query(DocumentSegment) .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) - .filter(DocumentSegment.document_id == self.id) + .where(DocumentSegment.document_id == self.id) .scalar() ) @property def uploader(self): - user = db.session.query(Account).filter(Account.id == self.created_by).first() + user = db.session.query(Account).where(Account.id == self.created_by).first() return user.name if user else None @property @@ -475,7 +475,7 @@ class Document(Base): document_metadatas = ( db.session.query(DatasetMetadata) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) - .filter( + .where( DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id ) .all() @@ -687,26 +687,26 @@ class DocumentSegment(Base): @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def document(self): - return db.session.query(Document).filter(Document.id == self.document_id).first() + return db.session.scalar(select(Document).where(Document.id == self.document_id)) @property def previous_segment(self): - return ( - db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) - .first() + return db.session.scalar( + select(DocumentSegment).where( + DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1 + ) ) @property def next_segment(self): - return ( - db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) - .first() + return db.session.scalar( + select(DocumentSegment).where( + DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1 + ) ) @property @@ -717,7 +717,7 @@ class DocumentSegment(Base): if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: child_chunks = ( db.session.query(ChildChunk) - .filter(ChildChunk.segment_id == self.id) + .where(ChildChunk.segment_id == self.id) .order_by(ChildChunk.position.asc()) .all() ) @@ -734,7 +734,7 @@ class DocumentSegment(Base): if rules.parent_mode: child_chunks = ( db.session.query(ChildChunk) - .filter(ChildChunk.segment_id == self.id) + .where(ChildChunk.segment_id == self.id) .order_by(ChildChunk.position.asc()) .all() ) @@ -825,15 +825,15 @@ class ChildChunk(Base): @property def dataset(self): - return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first() @property def document(self): - return db.session.query(Document).filter(Document.id == self.document_id).first() + return db.session.query(Document).where(Document.id == self.document_id).first() @property def segment(self): - return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first() class AppDatasetJoin(Base): @@ -1044,11 +1044,11 @@ class ExternalKnowledgeApis(Base): def dataset_bindings(self): external_knowledge_bindings = ( db.session.query(ExternalKnowledgeBindings) - .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .all() ) dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] - datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all() + datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() dataset_bindings = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) diff --git a/api/models/model.py b/api/models/model.py index b8e8b7801..a78a91ebd 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -113,13 +113,13 @@ class App(Base): @property def site(self): - site = db.session.query(Site).filter(Site.app_id == self.id).first() + site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property def app_model_config(self): if self.app_model_config_id: - return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @@ -128,7 +128,7 @@ class App(Base): if self.workflow_id: from .workflow import Workflow - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() return None @@ -138,7 +138,7 @@ class App(Base): @property def tenant(self): - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @property @@ -282,7 +282,7 @@ class App(Base): tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == self.id, TagBinding.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id, @@ -296,7 +296,7 @@ class App(Base): @property def author_name(self): if self.created_by: - account = db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name @@ -338,7 +338,7 @@ class AppModelConfig(Base): @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @property @@ -372,7 +372,7 @@ class AppModelConfig(Base): @property def annotation_reply_dict(self) -> dict: annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail @@ -577,7 +577,7 @@ class RecommendedApp(Base): @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -601,12 +601,12 @@ class InstalledApp(Base): @property def app(self): - app = db.session.query(App).filter(App.id == self.app_id).first() + app = db.session.query(App).where(App.id == self.app_id).first() return app @property def tenant(self): - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -714,7 +714,7 @@ class Conversation(Base): model_config["configs"] = override_model_configs else: app_model_config = ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() ) if app_model_config: model_config = app_model_config.to_dict() @@ -737,21 +737,21 @@ class Conversation(Base): @property def annotated(self): - return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0 + return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0 @property def annotation(self): - return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first() + return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first() @property def message_count(self): - return db.session.query(Message).filter(Message.conversation_id == self.id).count() + return db.session.query(Message).where(Message.conversation_id == self.id).count() @property def user_feedback_stats(self): like = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", MessageFeedback.rating == "like", @@ -761,7 +761,7 @@ class Conversation(Base): dislike = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", MessageFeedback.rating == "dislike", @@ -775,7 +775,7 @@ class Conversation(Base): def admin_feedback_stats(self): like = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", MessageFeedback.rating == "like", @@ -785,7 +785,7 @@ class Conversation(Base): dislike = ( db.session.query(MessageFeedback) - .filter( + .where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", MessageFeedback.rating == "dislike", @@ -797,7 +797,7 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() + messages = db.session.query(Message).where(Message.conversation_id == self.id).all() status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -824,19 +824,19 @@ class Conversation(Base): def first_message(self): return ( db.session.query(Message) - .filter(Message.conversation_id == self.id) + .where(Message.conversation_id == self.id) .order_by(Message.created_at.asc()) .first() ) @property def app(self): - return db.session.query(App).filter(App.id == self.app_id).first() + return db.session.query(App).where(App.id == self.app_id).first() @property def from_end_user_session_id(self): if self.from_end_user_id: - end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first() + end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first() if end_user: return end_user.session_id @@ -845,7 +845,7 @@ class Conversation(Base): @property def from_account_name(self): if self.from_account_id: - account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: return account.name @@ -1040,7 +1040,7 @@ class Message(Base): def user_feedback(self): feedback = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") .first() ) return feedback @@ -1049,30 +1049,30 @@ class Message(Base): def admin_feedback(self): feedback = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") .first() ) return feedback @property def feedbacks(self): - feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all() + feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() return feedbacks @property def annotation(self): - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first() return annotation @property def annotation_hit_history(self): annotation_history = ( - db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first() ) if annotation_history: annotation = ( db.session.query(MessageAnnotation) - .filter(MessageAnnotation.id == annotation_history.annotation_id) + .where(MessageAnnotation.id == annotation_history.annotation_id) .first() ) return annotation @@ -1080,11 +1080,9 @@ class Message(Base): @property def app_model_config(self): - conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() + conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() if conversation: - return ( - db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() - ) + return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first() return None @@ -1100,7 +1098,7 @@ class Message(Base): def agent_thoughts(self): return ( db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == self.id) + .where(MessageAgentThought.message_id == self.id) .order_by(MessageAgentThought.position.asc()) .all() ) @@ -1113,8 +1111,8 @@ class Message(Base): def message_files(self): from factories import file_factory - message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() - current_app = db.session.query(App).filter(App.id == self.app_id).first() + message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).where(App.id == self.app_id).first() if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1178,7 +1176,7 @@ class Message(Base): if self.workflow_run_id: from .workflow import WorkflowRun - return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first() return None @@ -1253,7 +1251,7 @@ class MessageFeedback(Base): @property def from_account(self): - account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account def to_dict(self): @@ -1335,12 +1333,12 @@ class MessageAnnotation(Base): @property def account(self): - account = db.session.query(Account).filter(Account.id == self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @property def annotation_create_account(self): - account = db.session.query(Account).filter(Account.id == self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @@ -1371,14 +1369,14 @@ class AppAnnotationHitHistory(Base): account = ( db.session.query(Account) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id) + .where(MessageAnnotation.id == self.annotation_id) .first() ) return account @property def annotation_create_account(self): - account = db.session.query(Account).filter(Account.id == self.account_id).first() + account = db.session.query(Account).where(Account.id == self.account_id).first() return account @@ -1404,7 +1402,7 @@ class AppAnnotationSetting(Base): collection_binding_detail = ( db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .where(DatasetCollectionBinding.id == self.collection_binding_id) .first() ) return collection_binding_detail @@ -1470,7 +1468,7 @@ class AppMCPServer(Base): def generate_server_code(n): while True: result = generate_string(n) - while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0: + while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: result = generate_string(n) return result @@ -1527,7 +1525,7 @@ class Site(Base): def generate_code(n): while True: result = generate_string(n) - while db.session.query(Site).filter(Site.code == result).count() > 0: + while db.session.query(Site).where(Site.code == result).count() > 0: result = generate_string(n) return result @@ -1558,7 +1556,7 @@ class ApiToken(Base): def generate_api_key(prefix, n): while True: result = prefix + generate_string(n) - if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: + if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: continue return result diff --git a/api/models/tools.py b/api/models/tools.py index 8c91e91f0..68f4211e5 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -153,11 +153,11 @@ class ApiToolProvider(Base): def user(self) -> Account | None: if not self.user_id: return None - return db.session.query(Account).filter(Account.id == self.user_id).first() + return db.session.query(Account).where(Account.id == self.user_id).first() @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() class ToolLabelBinding(Base): @@ -223,11 +223,11 @@ class WorkflowToolProvider(Base): @property def user(self) -> Account | None: - return db.session.query(Account).filter(Account.id == self.user_id).first() + return db.session.query(Account).where(Account.id == self.user_id).first() @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: @@ -235,7 +235,7 @@ class WorkflowToolProvider(Base): @property def app(self) -> App | None: - return db.session.query(App).filter(App.id == self.app_id).first() + return db.session.query(App).where(App.id == self.app_id).first() class MCPToolProvider(Base): @@ -280,11 +280,11 @@ class MCPToolProvider(Base): ) def load_user(self) -> Account | None: - return db.session.query(Account).filter(Account.id == self.user_id).first() + return db.session.query(Account).where(Account.id == self.user_id).first() @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property def credentials(self) -> dict: diff --git a/api/models/web.py b/api/models/web.py index bcc95ddbc..ce00f4010 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -26,7 +26,7 @@ class SavedMessage(Base): @property def message(self): - return db.session.query(Message).filter(Message.id == self.message_id).first() + return db.session.query(Message).where(Message.id == self.message_id).first() class PinnedConversation(Base): diff --git a/api/models/workflow.py b/api/models/workflow.py index 124fb3bb4..79d96e42d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -343,7 +343,7 @@ class Workflow(Base): return ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) + .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) .count() > 0 ) @@ -549,12 +549,12 @@ class WorkflowRun(Base): from models.model import Message return ( - db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() ) @property def workflow(self): - return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() def to_dict(self): return { diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 9efe120b7..024e3d6f5 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -21,7 +21,7 @@ def clean_embedding_cache_task(): try: embedding_ids = ( db.session.query(Embedding.id) - .filter(Embedding.created_at < thirty_days_ago) + .where(Embedding.created_at < thirty_days_ago) .order_by(Embedding.created_at.desc()) .limit(100) .all() diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index d02bc81f3..a6851e36e 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -36,7 +36,7 @@ def clean_messages(): # Main query with join and filter messages = ( db.session.query(Message) - .filter(Message.created_at < plan_sandbox_clean_message_day) + .where(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) .all() @@ -66,25 +66,25 @@ def clean_messages(): plan = plan_cache.decode() if plan == "sandbox": # clean related message - db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message.id).delete( + db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message.id).delete( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageChain).filter(MessageChain.message_id == message.id).delete( + db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).delete( + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( synchronize_session=False ) - db.session.query(MessageFile).filter(MessageFile.message_id == message.id).delete( + db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( synchronize_session=False ) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message.id).delete( + db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( synchronize_session=False ) - db.session.query(Message).filter(Message.id == message.id).delete() + db.session.query(Message).where(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index be228a6d9..72e2e73e6 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -27,7 +27,7 @@ def clean_unused_datasets_task(): # Subquery for counting new documents document_subquery_new = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -40,7 +40,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -55,7 +55,7 @@ def clean_unused_datasets_task(): select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .filter( + .where( Dataset.created_at < plan_sandbox_clean_day, func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0, @@ -72,7 +72,7 @@ def clean_unused_datasets_task(): for dataset in datasets: dataset_query = ( db.session.query(DatasetQuery) - .filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) + .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) .all() ) if not dataset_query or len(dataset_query) == 0: @@ -80,7 +80,7 @@ def clean_unused_datasets_task(): # add auto disable log documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset.id, Document.enabled == True, Document.archived == False, @@ -111,7 +111,7 @@ def clean_unused_datasets_task(): # Subquery for counting new documents document_subquery_new = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -124,7 +124,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .filter( + .where( Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, @@ -139,7 +139,7 @@ def clean_unused_datasets_task(): select(Dataset) .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .filter( + .where( Dataset.created_at < plan_pro_clean_day, func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0, @@ -155,7 +155,7 @@ def clean_unused_datasets_task(): for dataset in datasets: dataset_query = ( db.session.query(DatasetQuery) - .filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) + .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) .all() ) if not dataset_query or len(dataset_query) == 0: diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8a02278de..91953354e 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -20,7 +20,7 @@ def create_tidb_serverless_task(): try: # check the number of idle tidb serverless idle_tidb_serverless_number = ( - db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count() + db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count() ) if idle_tidb_serverless_number >= tidb_serverless_number: break diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 12e4f6ebf..5911c98b0 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -30,7 +30,7 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all() + db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() ) # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) @@ -45,7 +45,7 @@ def mail_clean_document_notify_task(): if plan != "sandbox": knowledge_details = [] # check tenant - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: continue # check current owner @@ -54,7 +54,7 @@ def mail_clean_document_notify_task(): ) if not current_owner_join: continue - account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first() + account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first() if not account: continue @@ -67,7 +67,7 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index ce4ecb6e7..4d6c1f187 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -17,7 +17,7 @@ def update_tidb_serverless_status_task(): # check the number of idle tidb serverless tidb_serverless_list = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") .all() ) if len(tidb_serverless_list) == 0: diff --git a/api/services/account_service.py b/api/services/account_service.py index 4c4510395..59bffa873 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -643,7 +643,7 @@ class AccountService: ) ) - account = db.session.query(Account).filter(Account.email == email).first() + account = db.session.query(Account).where(Account.email == email).first() if not account: return None @@ -900,7 +900,7 @@ class TenantService: return ( db.session.query(Tenant) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) - .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) .all() ) @@ -929,7 +929,7 @@ class TenantService: tenant_account_join = ( db.session.query(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) - .filter( + .where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id == tenant_id, Tenant.status == TenantStatus.NORMAL, @@ -940,7 +940,7 @@ class TenantService: if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.query(TenantAccountJoin).filter( + db.session.query(TenantAccountJoin).where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id ).update({"current": False}) tenant_account_join.current = True @@ -955,7 +955,7 @@ class TenantService: db.session.query(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(TenantAccountJoin.tenant_id == tenant.id) + .where(TenantAccountJoin.tenant_id == tenant.id) ) # Initialize an empty list to store the updated accounts @@ -974,8 +974,8 @@ class TenantService: db.session.query(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == "dataset_operator") + .where(TenantAccountJoin.tenant_id == tenant.id) + .where(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -995,9 +995,7 @@ class TenantService: return ( db.session.query(TenantAccountJoin) - .filter( - TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) - ) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])) .first() is not None ) @@ -1007,7 +1005,7 @@ class TenantService: """Get the role of the current account for a given tenant""" join = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .first() ) return TenantAccountRole(join.role) if join else None @@ -1274,7 +1272,7 @@ class RegisterService: tenant = ( db.session.query(Tenant) - .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .first() ) @@ -1284,7 +1282,7 @@ class RegisterService: tenant_account = ( db.session.query(Account, TenantAccountJoin.role) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) - .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) .first() ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 503b31ede..7c6df2428 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -25,7 +25,7 @@ class AgentService: conversation: Conversation | None = ( db.session.query(Conversation) - .filter( + .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, ) @@ -37,7 +37,7 @@ class AgentService: message: Optional[Message] = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.conversation_id == conversation_id, ) @@ -52,12 +52,10 @@ class AgentService: if conversation.from_end_user_id: # only select name field executor = ( - db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first() ) else: - executor = ( - db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() - ) + executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first() if executor: executor = executor.name diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 8c950abc2..7cb0b4651 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -26,7 +26,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -35,7 +35,7 @@ class AppAnnotationService: if args.get("message_id"): message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() + message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -61,9 +61,7 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -117,7 +115,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -126,8 +124,8 @@ class AppAnnotationService: if keyword: stmt = ( select(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( + .where(MessageAnnotation.app_id == app_id) + .where( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), MessageAnnotation.content.ilike("%{}%".format(keyword)), @@ -138,7 +136,7 @@ class AppAnnotationService: else: stmt = ( select(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -149,7 +147,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -157,7 +155,7 @@ class AppAnnotationService: raise NotFound("App not found") annotations = ( db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc()) .all() ) @@ -168,7 +166,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -181,9 +179,7 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -199,14 +195,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise NotFound("Annotation not found") @@ -217,7 +213,7 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: @@ -236,14 +232,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise NotFound("Annotation not found") @@ -252,7 +248,7 @@ class AppAnnotationService: annotation_hit_histories = ( db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .where(AppAnnotationHitHistory.annotation_id == annotation_id) .all() ) if annotation_hit_histories: @@ -262,7 +258,7 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , delete annotation index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: @@ -275,7 +271,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -314,21 +310,21 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: raise NotFound("Annotation not found") stmt = ( select(AppAnnotationHitHistory) - .filter( + .where( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) @@ -341,7 +337,7 @@ class AppAnnotationService: @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: - annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() if not annotation: return None @@ -361,7 +357,7 @@ class AppAnnotationService: score: float, ): # add hit count to annotation - db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update( {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) @@ -384,16 +380,14 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) if not app: raise NotFound("App not found") - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -412,7 +406,7 @@ class AppAnnotationService: # get app info app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) @@ -421,7 +415,7 @@ class AppAnnotationService: annotation_setting = ( db.session.query(AppAnnotationSetting) - .filter( + .where( AppAnnotationSetting.app_id == app_id, AppAnnotationSetting.id == annotation_setting_id, ) diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 601d67d2f..457c91e5c 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -73,7 +73,7 @@ class APIBasedExtensionService: db.session.query(APIBasedExtension) .filter_by(tenant_id=extension_data.tenant_id) .filter_by(name=extension_data.name) - .filter(APIBasedExtension.id != extension_data.id) + .where(APIBasedExtension.id != extension_data.id) .first() ) diff --git a/api/services/app_service.py b/api/services/app_service.py index cfcb414de..0b6b85bcb 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -382,7 +382,7 @@ class AppService: elif provider_type == "api": try: provider: Optional[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first() ) if provider is None: raise ValueError(f"provider not found for tool {tool_name}") @@ -399,7 +399,7 @@ class AppService: :param app_id: app id :return: app code """ - site = db.session.query(Site).filter(Site.app_id == app_id).first() + site = db.session.query(Site).where(Site.app_id == app_id).first() if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) @@ -411,7 +411,7 @@ class AppService: :param app_code: app code :return: app id """ - site = db.session.query(Site).filter(Site.code == app_code).first() + site = db.session.query(Site).where(Site.code == app_code).first() if not site: raise ValueError(f"App with code {app_code} not found") return str(site.app_id) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index e8923eb51..0084eebb3 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -135,7 +135,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.query(Message).filter(Message.id == message_id).first() + message = db.session.query(Message).where(Message.id == message_id).first() if message is None: return None if message.answer == "" and message.status == MessageStatus.NORMAL: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index e5f4a3ef6..996e9187f 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -11,7 +11,7 @@ class ApiKeyAuthService: def get_provider_auth_list(tenant_id: str) -> list: data_source_api_key_bindings = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) .all() ) return data_source_api_key_bindings @@ -36,7 +36,7 @@ class ApiKeyAuthService: def get_auth_credentials(tenant_id: str, category: str, provider: str): data_source_api_key_bindings = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter( + .where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, DataSourceApiKeyAuthBinding.provider == provider, @@ -53,7 +53,7 @@ class ApiKeyAuthService: def delete_provider_auth(tenant_id: str, binding_id: str): data_source_api_key_binding = ( db.session.query(DataSourceApiKeyAuthBinding) - .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) .first() ) if data_source_api_key_binding: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 9fffde073..5a12aa2e5 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -75,7 +75,7 @@ class BillingService: join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index ddd16b2e0..ad9b750d4 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -24,13 +24,13 @@ class ClearFreePlanTenantExpiredLogs: @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): - apps = db.session.query(App).filter(App.tenant_id == tenant_id).all() + apps = db.session.query(App).where(App.tenant_id == tenant_id).all() app_ids = [app.id for app in apps] while True: with Session(db.engine).no_autoflush as session: messages = ( session.query(Message) - .filter( + .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) @@ -54,7 +54,7 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).filter( + session.query(Message).where( Message.id.in_(message_ids), ).delete(synchronize_session=False) @@ -70,7 +70,7 @@ class ClearFreePlanTenantExpiredLogs: with Session(db.engine).no_autoflush as session: conversations = ( session.query(Conversation) - .filter( + .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) @@ -93,7 +93,7 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).filter( + session.query(Conversation).where( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) session.commit() @@ -276,7 +276,7 @@ class ClearFreePlanTenantExpiredLogs: for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, current_time + test_interval)) + .where(Tenant.created_at.between(current_time, current_time + test_interval)) .count() ) if tenant_count <= 100: @@ -301,7 +301,7 @@ class ClearFreePlanTenantExpiredLogs: rs = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, batch_end)) + .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 40097d5ed..525c87fe4 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -123,7 +123,7 @@ class ConversationService: # get conversation first message message = ( db.session.query(Message) - .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .order_by(Message.created_at.asc()) .first() ) @@ -148,7 +148,7 @@ class ConversationService: def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): conversation = ( db.session.query(Conversation) - .filter( + .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ce597420d..4872702a7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -80,7 +80,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids @@ -92,14 +92,14 @@ class DatasetService: if user.current_role == TenantAccountRole.DATASET_OPERATOR: # only show datasets that the user has permission to access if permitted_dataset_ids: - query = query.filter(Dataset.id.in_(permitted_dataset_ids)) + query = query.where(Dataset.id.in_(permitted_dataset_ids)) else: return [], 0 else: if user.current_role != TenantAccountRole.OWNER or not include_all: # show all datasets that the user has permission to access if permitted_dataset_ids: - query = query.filter( + query = query.where( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_( @@ -112,7 +112,7 @@ class DatasetService: ) ) else: - query = query.filter( + query = query.where( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_( @@ -122,15 +122,15 @@ class DatasetService: ) else: # if no user, only show datasets that are shared with all team members - query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) + query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f"%{search}%")) + query = query.where(Dataset.name.ilike(f"%{search}%")) if tag_ids: target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: - query = query.filter(Dataset.id.in_(target_ids)) + query = query.where(Dataset.id.in_(target_ids)) else: return [], 0 @@ -143,7 +143,7 @@ class DatasetService: # get the latest process rule dataset_process_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.dataset_id == dataset_id) + .where(DatasetProcessRule.dataset_id == dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) .one_or_none() @@ -158,7 +158,7 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): - stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) + stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) @@ -697,7 +697,7 @@ class DatasetService: def get_related_apps(dataset_id: str): return ( db.session.query(AppDatasetJoin) - .filter(AppDatasetJoin.dataset_id == dataset_id) + .where(AppDatasetJoin.dataset_id == dataset_id) .order_by(db.desc(AppDatasetJoin.created_at)) .all() ) @@ -714,7 +714,7 @@ class DatasetService: start_date = datetime.datetime.now() - datetime.timedelta(days=30) dataset_auto_disable_logs = ( db.session.query(DatasetAutoDisableLog) - .filter( + .where( DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.created_at >= start_date, ) @@ -843,7 +843,7 @@ class DocumentService: def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: if document_id: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) return document else: @@ -851,7 +851,7 @@ class DocumentService: @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() return document @@ -859,7 +859,7 @@ class DocumentService: def get_document_by_ids(document_ids: list[str]) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.id.in_(document_ids), Document.enabled == True, Document.indexing_status == "completed", @@ -873,7 +873,7 @@ class DocumentService: def get_document_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset_id, Document.enabled == True, ) @@ -886,7 +886,7 @@ class DocumentService: def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.dataset_id == dataset_id, Document.enabled == True, Document.indexing_status == "completed", @@ -901,7 +901,7 @@ class DocumentService: def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: documents = ( db.session.query(Document) - .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) .all() ) return documents @@ -910,7 +910,7 @@ class DocumentService: def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: documents = ( db.session.query(Document) - .filter( + .where( Document.batch == batch, Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id, @@ -922,7 +922,7 @@ class DocumentService: @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() + file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -950,7 +950,7 @@ class DocumentService: @staticmethod def delete_documents(dataset: Dataset, document_ids: list[str]): - documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() file_ids = [ document.data_source_info_dict["upload_file_id"] for document in documents @@ -1189,7 +1189,7 @@ class DocumentService: for file_id in upload_file_list: file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) @@ -1270,7 +1270,7 @@ class DocumentService: workspace_id = notion_info.workspace_id data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", @@ -1413,7 +1413,7 @@ class DocumentService: def get_tenant_documents_count(): documents_count = ( db.session.query(Document) - .filter( + .where( Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, @@ -1469,7 +1469,7 @@ class DocumentService: for file_id in upload_file_list: file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .first() ) @@ -1489,7 +1489,7 @@ class DocumentService: workspace_id = notion_info.workspace_id data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.provider == "notion", @@ -2005,7 +2005,7 @@ class SegmentService: with redis_client.lock(lock_name, timeout=600): max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == document.id) + .where(DocumentSegment.document_id == document.id) .scalar() ) segment_document = DocumentSegment( @@ -2043,7 +2043,7 @@ class SegmentService: segment_document.status = "error" segment_document.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() return segment @classmethod @@ -2062,7 +2062,7 @@ class SegmentService: ) max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == document.id) + .where(DocumentSegment.document_id == document.id) .scalar() ) pre_segment_data_list = [] @@ -2201,7 +2201,7 @@ class SegmentService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -2276,7 +2276,7 @@ class SegmentService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -2295,7 +2295,7 @@ class SegmentService: segment.status = "error" segment.error = str(e) db.session.commit() - new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() return new_segment @classmethod @@ -2321,7 +2321,7 @@ class SegmentService: index_node_ids = ( db.session.query(DocumentSegment) .with_entities(DocumentSegment.index_node_id) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2332,7 +2332,7 @@ class SegmentService: index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) - db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @classmethod @@ -2340,7 +2340,7 @@ class SegmentService: if action == "enable": segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2367,7 +2367,7 @@ class SegmentService: elif action == "disable": segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2404,7 +2404,7 @@ class SegmentService: index_node_hash = helper.generate_text_hash(content) child_chunk_count = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, @@ -2414,7 +2414,7 @@ class SegmentService: ) max_position = ( db.session.query(func.max(ChildChunk.position)) - .filter( + .where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, @@ -2457,7 +2457,7 @@ class SegmentService: ) -> list[ChildChunk]: child_chunks = ( db.session.query(ChildChunk) - .filter( + .where( ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, @@ -2578,7 +2578,7 @@ class SegmentService: """Get a child chunk by its ID.""" result = ( db.session.query(ChildChunk) - .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) + .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) .first() ) return result if isinstance(result, ChildChunk) else None @@ -2594,15 +2594,15 @@ class SegmentService: limit: int = 20, ): """Get segments for a document with optional filtering.""" - query = select(DocumentSegment).filter( + query = select(DocumentSegment).where( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) if status_list: - query = query.filter(DocumentSegment.status.in_(status_list)) + query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.order_by(DocumentSegment.position.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -2615,7 +2615,7 @@ class SegmentService: ) -> tuple[DocumentSegment, Document]: """Update a segment by its ID with validation and checks.""" # check dataset - dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -2647,7 +2647,7 @@ class SegmentService: # check segment segment = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) .first() ) if not segment: @@ -2664,7 +2664,7 @@ class SegmentService: """Get a segment by its ID.""" result = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .first() ) return result if isinstance(result, DocumentSegment) else None @@ -2677,7 +2677,7 @@ class DatasetCollectionBindingService: ) -> DatasetCollectionBinding: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.type == collection_type, @@ -2703,7 +2703,7 @@ class DatasetCollectionBindingService: ) -> DatasetCollectionBinding: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type ) .order_by(DatasetCollectionBinding.created_at) @@ -2722,7 +2722,7 @@ class DatasetPermissionService: db.session.query( DatasetPermission.account_id, ) - .filter(DatasetPermission.dataset_id == dataset_id) + .where(DatasetPermission.dataset_id == dataset_id) .all() ) @@ -2735,7 +2735,7 @@ class DatasetPermissionService: @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): try: - db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete() + db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() permissions = [] for user in user_list: permission = DatasetPermission( @@ -2771,7 +2771,7 @@ class DatasetPermissionService: @classmethod def clear_partial_member_list(cls, dataset_id): try: - db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete() + db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() db.session.commit() except Exception as e: db.session.rollback() diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 06a4c2211..b7af03e91 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -30,11 +30,11 @@ class ExternalDatasetService: ) -> tuple[list[ExternalKnowledgeApis], int | None]: query = ( select(ExternalKnowledgeApis) - .filter(ExternalKnowledgeApis.tenant_id == tenant_id) + .where(ExternalKnowledgeApis.tenant_id == tenant_id) .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/file_service.py b/api/services/file_service.py index 286535bd1..e234c2f32 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -144,7 +144,7 @@ class FileService: @staticmethod def get_file_preview(file_id: str): - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -167,7 +167,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -187,7 +187,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -198,7 +198,7 @@ class FileService: @staticmethod def get_public_image_preview(file_id: str): - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/message_service.py b/api/services/message_service.py index 51b070ece..283b7b9b4 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -50,7 +50,7 @@ class MessageService: if first_id: first_message = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .where(Message.conversation_id == conversation.id, Message.id == first_id) .first() ) @@ -59,7 +59,7 @@ class MessageService: history_messages = ( db.session.query(Message) - .filter( + .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, Message.id != first_message.id, @@ -71,7 +71,7 @@ class MessageService: else: history_messages = ( db.session.query(Message) - .filter(Message.conversation_id == conversation.id) + .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) .all() @@ -109,19 +109,19 @@ class MessageService: app_model=app_model, user=user, conversation_id=conversation_id ) - base_query = base_query.filter(Message.conversation_id == conversation.id) + base_query = base_query.where(Message.conversation_id == conversation.id) if include_ids is not None: - base_query = base_query.filter(Message.id.in_(include_ids)) + base_query = base_query.where(Message.id.in_(include_ids)) if last_id: - last_message = base_query.filter(Message.id == last_id).first() + last_message = base_query.where(Message.id == last_id).first() if not last_message: raise LastMessageNotExistsError() history_messages = ( - base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id) .order_by(Message.created_at.desc()) .limit(fetch_limit) .all() @@ -183,7 +183,7 @@ class MessageService: offset = (page - 1) * limit feedbacks = ( db.session.query(MessageFeedback) - .filter(MessageFeedback.app_id == app_model.id) + .where(MessageFeedback.app_id == app_model.id) .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) .limit(limit) .offset(offset) @@ -196,7 +196,7 @@ class MessageService: def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): message = ( db.session.query(Message) - .filter( + .where( Message.id == message_id, Message.app_id == app_model.id, Message.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -248,9 +248,7 @@ class MessageService: if not conversation.override_model_configs: app_model_config = ( db.session.query(AppModelConfig) - .filter( - AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id - ) + .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .first() ) else: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 26311a637..a200cfa14 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -103,7 +103,7 @@ class ModelLoadBalancingService: # Get load balancing configurations load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -219,7 +219,7 @@ class ModelLoadBalancingService: # Get load balancing configurations load_balancing_model_config = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -307,7 +307,7 @@ class ModelLoadBalancingService: current_load_balancing_configs = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), @@ -457,7 +457,7 @@ class ModelLoadBalancingService: # Get load balancing config load_balancing_model_config = ( db.session.query(LoadBalancingModelConfig) - .filter( + .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), diff --git a/api/services/ops_service.py b/api/services/ops_service.py index dbeb4f190..62f37c158 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -17,7 +17,7 @@ class OpsService: """ trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -25,7 +25,7 @@ class OpsService: return None # decrypt_token and obfuscated_token - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -148,7 +148,7 @@ class OpsService: # check if trace config already exists trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -156,7 +156,7 @@ class OpsService: return None # get tenant id - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -190,7 +190,7 @@ class OpsService: # check if trace config already exists current_trace_config = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) @@ -198,7 +198,7 @@ class OpsService: return None # get tenant id - app = db.session.query(App).filter(App.id == app_id).first() + app = db.session.query(App).where(App.id == app_id).first() if not app: return None tenant_id = app.tenant_id @@ -227,7 +227,7 @@ class OpsService: """ trace_config = ( db.session.query(TraceAppConfig) - .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() ) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index dbaaa7160..1806fbcfd 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -101,7 +101,7 @@ class PluginMigration: for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, current_time + test_interval)) + .where(Tenant.created_at.between(current_time, current_time + test_interval)) .count() ) if tenant_count <= 100: @@ -126,7 +126,7 @@ class PluginMigration: rs = ( session.query(Tenant.id) - .filter(Tenant.created_at.between(current_time, batch_end)) + .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -212,7 +212,7 @@ class PluginMigration: Extract tool tables. """ with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -226,7 +226,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all() + rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() result = [] for row in rs: graph = row.graph_dict @@ -249,7 +249,7 @@ class PluginMigration: Extract app tables. """ with Session(db.engine) as session: - apps = session.query(App).filter(App.tenant_id == tenant_id).all() + apps = session.query(App).where(App.tenant_id == tenant_id).all() if not apps: return [] @@ -257,7 +257,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value ] - rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() result = [] for row in rs: agent_config = row.agent_mode_dict diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index a1c5639e0..00b59dacb 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -51,7 +51,7 @@ class PluginParameterService: with Session(db.engine) as session: db_record = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, ) diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 275e49603..60fa26964 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -8,7 +8,7 @@ class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: with Session(db.engine) as session: - return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() @staticmethod def change_permission( @@ -18,7 +18,7 @@ class PluginPermissionService: ): with Session(db.engine) as session: permission = ( - session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() ) if not permission: permission = TenantPluginPermission( diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 3295516cc..b97d13d01 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -33,14 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ recommended_apps = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) .all() ) if len(recommended_apps) == 0: recommended_apps = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) .all() ) @@ -83,7 +83,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): # is in public recommended list recommended_app = ( db.session.query(RecommendedApp) - .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) .first() ) @@ -91,7 +91,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return None # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() + app_model = db.session.query(App).where(App.id == app_id).first() if not app_model or not app_model.is_public: return None diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4cb870011..641e03c3c 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -17,7 +17,7 @@ class SavedMessageService: raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), SavedMessage.created_by == user.id, @@ -37,7 +37,7 @@ class SavedMessageService: return saved_message = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), @@ -67,7 +67,7 @@ class SavedMessageService: return saved_message = ( db.session.query(SavedMessage) - .filter( + .where( SavedMessage.app_id == app_model.id, SavedMessage.message_id == message_id, SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 74c6150b4..75fa52a75 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -16,10 +16,10 @@ class TagService: query = ( db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) - .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) + .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results @@ -28,7 +28,7 @@ class TagService: def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: tags = ( db.session.query(Tag) - .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .all() ) if not tags: @@ -36,7 +36,7 @@ class TagService: tag_ids = [tag.id for tag in tags] tag_bindings = ( db.session.query(TagBinding.target_id) - .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) .all() ) if not tag_bindings: @@ -50,7 +50,7 @@ class TagService: return [] tags = ( db.session.query(Tag) - .filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .all() ) if not tags: @@ -62,7 +62,7 @@ class TagService: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) - .filter( + .where( TagBinding.target_id == target_id, TagBinding.tenant_id == current_tenant_id, Tag.tenant_id == current_tenant_id, @@ -92,7 +92,7 @@ class TagService: def update_tags(args: dict, tag_id: str) -> Tag: if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): raise ValueError("Tag name already exists") - tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + tag = db.session.query(Tag).where(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") tag.name = args["name"] @@ -101,17 +101,17 @@ class TagService: @staticmethod def get_tag_binding_count(tag_id: str) -> int: - count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count() + count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count() return count @staticmethod def delete_tag(tag_id: str): - tag = db.session.query(Tag).filter(Tag.id == tag_id).first() + tag = db.session.query(Tag).where(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") db.session.delete(tag) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all() + tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() if tag_bindings: for tag_binding in tag_bindings: db.session.delete(tag_binding) @@ -125,7 +125,7 @@ class TagService: for tag_id in args["tag_ids"]: tag_binding = ( db.session.query(TagBinding) - .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) .first() ) if tag_binding: @@ -146,7 +146,7 @@ class TagService: # delete tag binding tag_bindings = ( db.session.query(TagBinding) - .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) .first() ) if tag_bindings: @@ -158,7 +158,7 @@ class TagService: if type == "knowledge": dataset = ( db.session.query(Dataset) - .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) .first() ) if not dataset: @@ -166,7 +166,7 @@ class TagService: elif type == "app": app = ( db.session.query(App) - .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .where(App.tenant_id == current_user.current_tenant_id, App.id == target_id) .first() ) if not app: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 80badf233..78e587abe 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -119,7 +119,7 @@ class ApiToolManageService: # check if the provider exists provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -210,7 +210,7 @@ class ApiToolManageService: """ provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -257,7 +257,7 @@ class ApiToolManageService: # check if the provider exists provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == original_provider, ) @@ -326,7 +326,7 @@ class ApiToolManageService: """ provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -376,7 +376,7 @@ class ApiToolManageService: db_provider = ( db.session.query(ApiToolProvider) - .filter( + .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider_name, ) @@ -444,7 +444,7 @@ class ApiToolManageService: """ # get all api providers db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] ) result: list[ToolProviderApiEntity] = [] diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b8e3ce265..65f05d298 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -154,7 +154,7 @@ class BuiltinToolManageService: # get if the provider exists db_provider = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) @@ -404,7 +404,7 @@ class BuiltinToolManageService: with Session(db.engine) as session: db_provider = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) @@ -613,7 +613,7 @@ class BuiltinToolManageService: if provider_id_entity.organization != "langgenius": provider = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == full_provider_name, ) @@ -626,7 +626,7 @@ class BuiltinToolManageService: else: provider = ( session.query(BuiltinToolProvider) - .filter( + .where( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == full_provider_name), @@ -647,7 +647,7 @@ class BuiltinToolManageService: # it's an old provider without organization return ( session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index c0126a0f4..23be449a5 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -31,7 +31,7 @@ class MCPToolManageService: def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider: res = ( db.session.query(MCPToolProvider) - .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id) .first() ) if not res: @@ -42,7 +42,7 @@ class MCPToolManageService: def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider: res = ( db.session.query(MCPToolProvider) - .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) + .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier) .first() ) if not res: @@ -63,7 +63,7 @@ class MCPToolManageService: server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() existing_provider = ( db.session.query(MCPToolProvider) - .filter( + .where( MCPToolProvider.tenant_id == tenant_id, or_( MCPToolProvider.name == name, @@ -100,7 +100,7 @@ class MCPToolManageService: def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]: mcp_providers = ( db.session.query(MCPToolProvider) - .filter(MCPToolProvider.tenant_id == tenant_id) + .where(MCPToolProvider.tenant_id == tenant_id) .order_by(MCPToolProvider.name) .all() ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index c6b205557..75da5e5ea 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -43,7 +43,7 @@ class WorkflowToolManageService: # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, # name or app_id or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), @@ -54,7 +54,7 @@ class WorkflowToolManageService: if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() + app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: raise ValueError(f"App {workflow_app_id} not found") @@ -123,7 +123,7 @@ class WorkflowToolManageService: # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.name == name, WorkflowToolProvider.id != workflow_tool_id, @@ -136,7 +136,7 @@ class WorkflowToolManageService: workflow_tool_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) @@ -144,7 +144,7 @@ class WorkflowToolManageService: raise ValueError(f"Tool {workflow_tool_id} not found") app: App | None = ( - db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: @@ -186,7 +186,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() tools: list[WorkflowToolProviderController] = [] for provider in db_tools: @@ -224,7 +224,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.query(WorkflowToolProvider).filter( + db.session.query(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() @@ -243,7 +243,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) return cls._get_workflow_tool(tenant_id, db_tool) @@ -259,7 +259,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() ) return cls._get_workflow_tool(tenant_id, db_tool) @@ -275,7 +275,7 @@ class WorkflowToolManageService: raise ValueError("Tool not found") workflow_app: App | None = ( - db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() + db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() ) if workflow_app is None: @@ -318,7 +318,7 @@ class WorkflowToolManageService: """ db_tool: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() ) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 916513919..f9ec05459 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -36,7 +36,7 @@ class VectorService: # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index f698ed308..c48e24f24 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -65,7 +65,7 @@ class WebConversationService: return pinned_conversation = ( db.session.query(PinnedConversation) - .filter( + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), @@ -97,7 +97,7 @@ class WebConversationService: return pinned_conversation = ( db.session.query(PinnedConversation) - .filter( + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 8f92b3f07..a9df8d0d7 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -52,7 +52,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = db.session.query(Account).filter(Account.email == email).first() + account = db.session.query(Account).where(Account.email == email).first() if not account: return None @@ -91,10 +91,10 @@ class WebAppAuthService: @classmethod def create_end_user(cls, app_code, email) -> EndUser: - site = db.session.query(Site).filter(Site.code == app_code).first() + site = db.session.query(Site).where(Site.code == app_code).first() if not site: raise NotFound("Site not found.") - app_model = db.session.query(App).filter(App.id == site.app_id).first() + app_model = db.session.query(App).where(App.id == site.app_id).first() if not app_model: raise NotFound("App not found.") end_user = EndUser( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2b0d57bdf..abf6824d7 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -620,7 +620,7 @@ class WorkflowConverter: """ api_based_extension = ( db.session.query(APIBasedExtension) - .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .first() ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f306e1f06..3164e010b 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -138,7 +138,7 @@ class WorkflowDraftVariableService: ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + return self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id).first() def get_draft_variables_by_selectors( self, @@ -166,7 +166,7 @@ class WorkflowDraftVariableService: def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: criteria = WorkflowDraftVariable.app_id == app_id total = None - query = self._session.query(WorkflowDraftVariable).filter(criteria) + query = self._session.query(WorkflowDraftVariable).where(criteria) if page == 1: total = query.count() variables = ( @@ -185,7 +185,7 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, ) - query = self._session.query(WorkflowDraftVariable).filter(*criteria) + query = self._session.query(WorkflowDraftVariable).where(*criteria) variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all() return WorkflowDraftVariableList(variables=variables) @@ -328,7 +328,7 @@ class WorkflowDraftVariableService: def delete_workflow_variables(self, app_id: str): ( self._session.query(WorkflowDraftVariable) - .filter(WorkflowDraftVariable.app_id == app_id) + .where(WorkflowDraftVariable.app_id == app_id) .delete(synchronize_session=False) ) @@ -379,7 +379,7 @@ class WorkflowDraftVariableService: if conv_id is not None: conversation = ( self._session.query(Conversation) - .filter( + .where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 403e55974..e9f21fc5f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -89,7 +89,7 @@ class WorkflowService: def is_workflow_exist(self, app_model: App) -> bool: return ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, @@ -104,7 +104,7 @@ class WorkflowService: # fetch draft workflow by app_model workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" ) .first() @@ -117,7 +117,7 @@ class WorkflowService: # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id, @@ -141,7 +141,7 @@ class WorkflowService: # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == app_model.workflow_id, @@ -658,7 +658,7 @@ class WorkflowService: # Check if there's a tool provider using this specific workflow version tool_provider = ( session.query(WorkflowToolProvider) - .filter( + .where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index bb35645c5..d4fc68a08 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -25,7 +25,7 @@ class WorkspaceService: # Get role of user tenant_account_join = ( db.session.query(TenantAccountJoin) - .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) assert tenant_account_join is not None, "TenantAccountJoin not found" diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 75d648e1b..204c1a4f5 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -25,7 +25,7 @@ def add_document_to_index_task(dataset_document_id: str): logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() if not dataset_document: logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red")) db.session.close() @@ -43,7 +43,7 @@ def add_document_to_index_task(dataset_document_id: str): segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == False, DocumentSegment.status == "completed", @@ -86,12 +86,10 @@ def add_document_to_index_task(dataset_document_id: str): index_processor.load(dataset, documents) # delete auto disable log - db.session.query(DatasetAutoDisableLog).filter( - DatasetAutoDisableLog.document_id == dataset_document.id - ).delete() + db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() # update segment to enable - db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update( + db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( { DocumentSegment.enabled: True, DocumentSegment.disabled_at: None, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6144a4fe3..6d48f5df8 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -29,7 +29,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: start_at = time.perf_counter() indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: @@ -48,7 +48,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: documents.append(document) # if annotation reply is enabled , batch add annotations' index app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) if app_annotation_setting: diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 747fce578..5d5d1d3ad 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -19,16 +19,14 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() if not app: logging.info(click.style("App not found: {}".format(app_id), fg="red")) db.session.close() return - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if not app_annotation_setting: logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red")) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index c04f1be84..12d10df44 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -30,14 +30,14 @@ def enable_annotation_reply_task( logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: logging.info(click.style("App not found: {}".format(app_id), fg="red")) db.session.close() return - annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() + annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) @@ -46,9 +46,7 @@ def enable_annotation_reply_task( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, "annotation" ) - annotation_setting = ( - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() - ) + annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: if dataset_collection_binding.id != annotation_setting.collection_binding_id: old_dataset_collection_binding = ( diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 97efc47b3..49bff72a9 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -27,12 +27,12 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -42,7 +42,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() try: if image_file and image_file.key: storage.delete(image_file.key) @@ -56,7 +56,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form db.session.commit() if file_ids: - files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() for file in files: try: storage.delete(file.key) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 51b6343fd..64df3175e 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -81,7 +81,7 @@ def batch_create_segment_to_index_task( segment_hash = helper.generate_text_hash(content) # type: ignore max_position = ( db.session.query(func.max(DocumentSegment.position)) - .filter(DocumentSegment.document_id == dataset_document.id) + .where(DocumentSegment.document_id == dataset_document.id) .scalar() ) segment_document = DocumentSegment( diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 6bac71839..fad090141 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -53,8 +53,8 @@ def clean_dataset_task( index_struct=index_struct, collection_binding_id=collection_binding_id, ) - documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() + documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) @@ -72,7 +72,7 @@ def clean_dataset_task( for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if image_file is None: continue try: @@ -85,12 +85,12 @@ def clean_dataset_task( db.session.delete(image_file) db.session.delete(segment) - db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() - db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() - db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() + db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() + db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() + db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() # delete dataset metadata - db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete() - db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete() + db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() + db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() # delete files if documents: for document in documents: @@ -102,7 +102,7 @@ def clean_dataset_task( file_id = data_source_info["upload_file_id"] file = ( db.session.query(UploadFile) - .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .first() ) if not file: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index c72a3319c..dd7a544ff 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,12 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -43,7 +43,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if image_file is None: continue try: @@ -58,7 +58,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -68,7 +68,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() # delete dataset metadata binding - db.session.query(DatasetMetadataBinding).filter( + db.session.query(DatasetMetadataBinding).where( DatasetMetadataBinding.dataset_id == dataset_id, DatasetMetadataBinding.document_id == document_id, ).delete() diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 1087a3776..0f72f87f1 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -24,17 +24,17 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() db.session.delete(document) - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 5710d660b..5eda24674 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -24,7 +24,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index a27207f2f..7478bf5a9 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -35,7 +35,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): elif action == "add": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -46,7 +46,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -56,7 +56,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): # add from vector index segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) @@ -76,19 +76,19 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() elif action == "update": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -100,7 +100,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -113,7 +113,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): try: segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) @@ -148,12 +148,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index 52c884ca2..d3b33e305 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_account_task(account_id): - account = db.session.query(Account).filter(Account.id == account_id).first() + account = db.session.query(Account).where(Account.id == account_id).first() try: BillingService.delete_account(account_id) except Exception as e: diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index a93babc31..66ff0f9a0 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -22,11 +22,11 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: return - dataset_document = db.session.query(Document).filter(Document.id == document_id).first() + dataset_document = db.session.query(Document).where(Document.id == document_id).first() if not dataset_document: return diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 327eed472..e67ba5c76 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -21,7 +21,7 @@ def disable_segment_from_index_task(segment_id: str): logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 8b77b290c..0c8b1aabc 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -23,13 +23,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) db.session.close() return - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) @@ -44,7 +44,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, @@ -64,7 +64,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) except Exception: # update segment error msg - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index b4848be19..dcc748ef1 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -25,7 +25,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) @@ -46,7 +46,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_edited_time = data_source_info["last_edited_time"] data_source_binding = ( db.session.query(DataSourceOauthBinding) - .filter( + .where( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.provider == "notion", @@ -77,13 +77,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # delete all document segment and index try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index a85aab0bb..ec6d10d93 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -24,7 +24,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow")) db.session.close() @@ -48,7 +48,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): except Exception as e: for document_id in document_ids: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" @@ -63,7 +63,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): logging.info(click.style("Start process document: {}".format(document_id), fg="green")) document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 167b928f5..e53c38ddc 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -23,7 +23,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) @@ -36,14 +36,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str): # delete all document segment and index try: - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index a6c93e110..b3ddface5 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -25,7 +25,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset is None: logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) db.session.close() @@ -50,7 +50,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): except Exception as e: for document_id in document_ids: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" @@ -66,7 +66,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logging.info(click.style("Start process document: {}".format(document_id), fg="green")) document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: @@ -74,7 +74,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 21f08f40a..13822f078 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -24,7 +24,7 @@ def enable_segment_to_index_task(segment_id: str): logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) db.session.close() diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 625a3b582..e3fdf04d8 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -25,12 +25,12 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) return - dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) @@ -45,7 +45,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i segments = ( db.session.query(DocumentSegment) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, @@ -95,7 +95,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i except Exception as e: logging.exception("enable segments to index failed") # update segment error msg - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index e7d49c78d..dfb238957 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -21,7 +21,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logging.info(click.style("Recover document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 179adcbd6..1619f8c54 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -76,7 +76,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): def del_model_config(model_config_id: str): - db.session.query(AppModelConfig).filter(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -88,14 +88,14 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): - db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False) + db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") def _delete_app_mcp_servers(tenant_id: str, app_id: str): def del_mcp_server(mcp_server_id: str): - db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -107,7 +107,7 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(api_token_id: str): - db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False) + db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" @@ -116,7 +116,7 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): def del_installed_app(installed_app_id: str): - db.session.query(InstalledApp).filter(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -128,7 +128,7 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(recommended_app_id: str): - db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete( + db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete( synchronize_session=False ) @@ -142,9 +142,9 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(annotation_hit_history_id: str): - db.session.query(AppAnnotationHitHistory).filter( - AppAnnotationHitHistory.id == annotation_hit_history_id - ).delete(synchronize_session=False) + db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( + synchronize_session=False + ) _delete_records( """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", @@ -154,7 +154,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): ) def del_annotation_setting(annotation_setting_id: str): - db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete( + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( synchronize_session=False ) @@ -168,7 +168,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): def del_dataset_join(dataset_join_id: str): - db.session.query(AppDatasetJoin).filter(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -180,7 +180,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): def del_workflow(workflow_id: str): - db.session.query(Workflow).filter(Workflow.id == workflow_id).delete(synchronize_session=False) + db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -220,7 +220,7 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete( + db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete( synchronize_session=False ) @@ -234,10 +234,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(conversation_id: str): - db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete( + db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( synchronize_session=False ) - db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False) + db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -257,19 +257,19 @@ def _delete_conversation_variables(*, app_id: str): def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): - db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete( + db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False) - db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete( + db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False) - db.session.query(Message).filter(Message.id == message_id).delete() + db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) + db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) + db.session.query(Message).where(Message.id == message_id).delete() _delete_records( """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" @@ -278,7 +278,7 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(tool_provider_id: str): - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete( + db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( synchronize_session=False ) @@ -292,7 +292,7 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): def del_tag_binding(tag_binding_id: str): - db.session.query(TagBinding).filter(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -304,7 +304,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): def del_end_user(end_user_id: str): - db.session.query(EndUser).filter(EndUser.id == end_user_id).delete(synchronize_session=False) + db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -316,7 +316,7 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(trace_app_config_id: str): - db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete( + db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete( synchronize_session=False ) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 0e2960788..3f73cc7b4 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -22,7 +22,7 @@ def remove_document_from_index_task(document_id: str): logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="red")) db.session.close() @@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: @@ -51,7 +51,7 @@ def remove_document_from_index_task(document_id: str): except Exception: logging.exception(f"clean dataset {dataset.id} from index failed") # update segment to disable - db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update( + db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( { DocumentSegment.enabled: False, DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 8f8c3f9d8..58f0156af 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -25,7 +25,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): documents: list[Document] = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) db.session.close() @@ -45,7 +45,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): ) except Exception as e: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" @@ -59,7 +59,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) @@ -69,7 +69,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index dba0a39c2..539c2db80 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -24,7 +24,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset is None: raise ValueError("Dataset not found") @@ -41,7 +41,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): ) except Exception as e: document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: document.indexing_status = "error" @@ -53,7 +53,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): return logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) return @@ -61,7 +61,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 30cd2e60c..e96d70c4a 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -214,7 +214,7 @@ class TestDraftVariableLoader(unittest.TestCase): def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: - session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( synchronize_session=False ) session.commit() diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 61cf8f255..589000974 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -44,7 +44,7 @@ class TestEncryptToken: """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -55,7 +55,7 @@ class TestEncryptToken: @patch("models.engine.db.session.query") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.filter.return_value.first.return_value = None + mock_query.return_value.where.return_value.first.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -153,7 +153,7 @@ class TestSecurity: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -186,7 +186,7 @@ class TestSecurity: def test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -211,7 +211,7 @@ class TestEdgeCases: """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -225,7 +225,7 @@ class TestEdgeCases: """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -248,7 +248,7 @@ class TestEdgeCases: """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.filter.return_value.first.return_value = mock_tenant + mock_query.return_value.where.return_value.first.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 48463a369..d42c4412f 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -54,8 +54,7 @@ def mock_tool_file(): mock.mimetype = "application/pdf" mock.original_url = "http://example.com/tool.pdf" mock.size = 2048 - with patch("factories.file_factory.db.session.query") as mock_query: - mock_query.return_value.filter.return_value.first.return_value = mock + with patch("factories.file_factory.db.session.scalar", return_value=mock): yield mock @@ -153,8 +152,7 @@ def test_build_from_remote_url(mock_http_head): def test_tool_file_not_found(): """Test ToolFile not found in database.""" - with patch("factories.file_factory.db.session.query") as mock_query: - mock_query.return_value.filter.return_value.first.return_value = None + with patch("factories.file_factory.db.session.scalar", return_value=None): mapping = tool_file_mapping() with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/unit_tests/models/test_types_enum_text.py index 908b5a536..e4061b72c 100644 --- a/api/tests/unit_tests/models/test_types_enum_text.py +++ b/api/tests/unit_tests/models/test_types_enum_text.py @@ -114,12 +114,12 @@ class TestEnumText: session.commit() with Session(engine) as session: - user = session.query(_User).filter(_User.id == admin_user_id).first() + user = session.query(_User).where(_User.id == admin_user_id).first() assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine) as session: - user = session.query(_User).filter(_User.id == normal_user_id).first() + user = session.query(_User).where(_User.id == normal_user_id).first() assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -188,4 +188,4 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine) as session: - _user = session.query(_User).filter(_User.id == 1).first() + _user = session.query(_User).where(_User.id == 1).first() diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index f0e425e74..dc42a04cf 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -28,7 +28,7 @@ class TestApiKeyAuthService: mock_binding.provider = self.provider mock_binding.disabled = False - mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding] + mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) @@ -39,7 +39,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_empty(self, mock_session): """Test get provider auth list - empty result""" - mock_session.query.return_value.filter.return_value.all.return_value = [] + mock_session.query.return_value.where.return_value.all.return_value = [] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) @@ -48,13 +48,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_filters_disabled(self, mock_session): """Test get provider auth list - filters disabled items""" - mock_session.query.return_value.filter.return_value.all.return_value = [] + mock_session.query.return_value.where.return_value.all.return_value = [] ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - # Verify filter conditions include disabled.is_(False) - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 2 # tenant_id and disabled filter conditions + # Verify where conditions include disabled.is_(False) + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 2 # tenant_id and disabled filter conditions @patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @@ -138,7 +138,8 @@ class TestApiKeyAuthService: # Mock database query result mock_binding = Mock() mock_binding.credentials = json.dumps(self.mock_credentials) - mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + mock_session.query.return_value.where.return_value.first.return_value = mock_binding + mock_session.query.return_value.where.return_value.first.return_value = mock_binding result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) @@ -148,7 +149,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_auth_credentials_not_found(self, mock_session): """Test get auth credentials - not found""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) @@ -157,13 +158,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_auth_credentials_filters_correctly(self, mock_session): """Test get auth credentials - applies correct filters""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - # Verify filter conditions are correct - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 4 # tenant_id, category, provider, disabled + # Verify where conditions are correct + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 4 # tenant_id, category, provider, disabled @patch("services.auth.api_key_auth_service.db.session") def test_get_auth_credentials_json_parsing(self, mock_session): @@ -173,7 +174,7 @@ class TestApiKeyAuthService: mock_binding = Mock() mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False) - mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + mock_session.query.return_value.where.return_value.first.return_value = mock_binding result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) @@ -185,7 +186,7 @@ class TestApiKeyAuthService: """Test delete provider auth - success scenario""" # Mock database query result mock_binding = Mock() - mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + mock_session.query.return_value.where.return_value.first.return_value = mock_binding ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) @@ -196,7 +197,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_delete_provider_auth_not_found(self, mock_session): """Test delete provider auth - not found""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) @@ -207,13 +208,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_delete_provider_auth_filters_by_tenant(self, mock_session): """Test delete provider auth - filters by tenant""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) - # Verify filter conditions include tenant_id and binding_id - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 2 + # Verify where conditions include tenant_id and binding_id + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 2 def test_validate_api_key_auth_args_success(self): """Test API key auth args validation - success scenario""" @@ -336,7 +337,7 @@ class TestApiKeyAuthService: # Mock database returning invalid JSON mock_binding = Mock() mock_binding.credentials = "invalid json content" - mock_session.query.return_value.filter.return_value.first.return_value = mock_binding + mock_session.query.return_value.where.return_value.first.return_value = mock_binding with pytest.raises(json.JSONDecodeError): ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index 31a617345..4ce552594 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -63,10 +63,10 @@ class TestAuthIntegration: tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding] + mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding] + mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) assert len(result1) == 1 @@ -77,7 +77,7 @@ class TestAuthIntegration: @patch("services.auth.api_key_auth_service.db.session") def test_cross_tenant_access_prevention(self, mock_session): """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 13900ab6d..442839e44 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -708,9 +708,9 @@ class TestTenantService: with patch("services.account_service.db") as mock_db: # Mock the join query that returns the tenant_account_join mock_query = MagicMock() - mock_filter = MagicMock() - mock_filter.first.return_value = mock_tenant_join - mock_query.filter.return_value = mock_filter + mock_where = MagicMock() + mock_where.first.return_value = mock_tenant_join + mock_query.where.return_value = mock_where mock_query.join.return_value = mock_query mock_db.session.query.return_value = mock_query @@ -1381,10 +1381,10 @@ class TestRegisterService: # Mock database queries - complex query mocking mock_query1 = MagicMock() - mock_query1.filter.return_value.first.return_value = mock_tenant + mock_query1.where.return_value.first.return_value = mock_tenant mock_query2 = MagicMock() - mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal") + mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal") mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] @@ -1449,7 +1449,7 @@ class TestRegisterService: mock_query1.filter.return_value.first.return_value = mock_tenant mock_query2 = MagicMock() - mock_query2.join.return_value.filter.return_value.first.return_value = None # No account found + mock_query2.join.return_value.where.return_value.first.return_value = None # No account found mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] @@ -1482,7 +1482,7 @@ class TestRegisterService: mock_query1.filter.return_value.first.return_value = mock_tenant mock_query2 = MagicMock() - mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal") + mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal") mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py index 2c87eaf80..dfe325648 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py @@ -43,7 +43,7 @@ def test_delete_workflow_success(workflow_setup): # Setup mocks # Mock the tool provider query to return None (not published as a tool) - workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None + workflow_setup["session"].query.return_value.where.return_value.first.return_value = None workflow_setup["session"].scalar = MagicMock( side_effect=[workflow_setup["workflow"], None] @@ -106,7 +106,7 @@ def test_delete_workflow_published_as_tool_error(workflow_setup): # Mock the tool provider query mock_tool_provider = MagicMock(spec=WorkflowToolProvider) - workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider + workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider workflow_setup["session"].scalar = MagicMock( side_effect=[workflow_setup["workflow"], None]