From 24e2b72b716326f34ecb234ff467d3e26a24759e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 31 Aug 2025 18:03:51 +0900 Subject: [PATCH] Update ast-grep pattern for session.query (#24828) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .github/workflows/autofix.yml | 1 + api/controllers/console/app/message.py | 2 +- api/schedule/check_upgradable_plugin_task.py | 2 +- .../clean_workflow_runlogs_precise.py | 2 +- api/services/annotation_service.py | 4 ++-- .../clear_free_plan_tenant_expired_logs.py | 12 +++++------ api/services/dataset_service.py | 2 +- .../plugin/plugin_auto_upgrade_service.py | 6 +++--- .../services/test_annotation_service.py | 2 +- .../services/test_app_dsl_service.py | 6 +++--- ...est_clear_free_plan_tenant_expired_logs.py | 20 +++++++++---------- 11 files changed, 30 insertions(+), 29 deletions(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 65f413af8..82ba95444 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -26,6 +26,7 @@ jobs: - name: ast-grep run: | uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all + uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all - name: mdformat run: | uvx mdformat . diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index fd86191a0..f0605a37f 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -130,7 +130,7 @@ class MessageFeedbackApi(Resource): message_id = str(args["message_id"]) - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index e27391b55..08a5cfce7 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -20,7 +20,7 @@ def check_upgradable_plugin_task(): strategies = ( db.session.query(TenantPluginAutoUpgradeStrategy) - .filter( + .where( TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day, TenantPluginAutoUpgradeStrategy.upgrade_time_of_day < now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL, diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py index 75057983f..1a0362ec3 100644 --- a/api/schedule/clean_workflow_runlogs_precise.py +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -93,7 +93,7 @@ def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> with db.session.begin_nested(): message_data = ( db.session.query(Message.id, Message.conversation_id) - .filter(Message.workflow_run_id.in_(workflow_run_ids)) + .where(Message.workflow_run_id.in_(workflow_run_ids)) .all() ) message_id_list = [msg.id for msg in message_data] diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 45b246af1..6603063c2 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -282,7 +282,7 @@ class AppAnnotationService: annotations_to_delete = ( db.session.query(MessageAnnotation, AppAnnotationSetting) .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id) - .filter(MessageAnnotation.id.in_(annotation_ids)) + .where(MessageAnnotation.id.in_(annotation_ids)) .all() ) @@ -493,7 +493,7 @@ class AppAnnotationService: def clear_all_annotations(cls, app_id: str) -> dict: app = ( db.session.query(App) - .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .first() ) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b28afcaa4..de00e7463 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -62,7 +62,7 @@ class ClearFreePlanTenantExpiredLogs: # Query records related to expired messages records = ( session.query(model) - .filter( + .where( model.message_id.in_(batch_message_ids), # type: ignore ) .all() @@ -101,7 +101,7 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception("Failed to save %s records", table_name) - session.query(model).filter( + session.query(model).where( model.id.in_(record_ids), # type: ignore ).delete(synchronize_session=False) @@ -295,7 +295,7 @@ class ClearFreePlanTenantExpiredLogs: with Session(db.engine).no_autoflush as session: workflow_app_logs = ( session.query(WorkflowAppLog) - .filter( + .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) @@ -321,9 +321,9 @@ class ClearFreePlanTenantExpiredLogs: workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs - session.query(WorkflowAppLog).filter( - WorkflowAppLog.id.in_(workflow_app_log_ids), - ).delete(synchronize_session=False) + session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( + synchronize_session=False + ) session.commit() click.echo( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 84860fd17..bbebb7a92 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2346,7 +2346,7 @@ class SegmentService: def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): segments = ( db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) - .filter( + .where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 377405044..174bed488 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -10,7 +10,7 @@ class PluginAutoUpgradeService: with Session(db.engine) as session: return ( session.query(TenantPluginAutoUpgradeStrategy) - .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) .first() ) @@ -26,7 +26,7 @@ class PluginAutoUpgradeService: with Session(db.engine) as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) - .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) .first() ) if not exist_strategy: @@ -54,7 +54,7 @@ class PluginAutoUpgradeService: with Session(db.engine) as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) - .filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) + .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) .first() ) if not exist_strategy: diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 92d93d601..418442088 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -674,7 +674,7 @@ class TestAnnotationService: history = ( db.session.query(AppAnnotationHitHistory) - .filter( + .where( AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id ) .first() diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index fc614b229..d83983d0f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -166,7 +166,7 @@ class TestAppDslService: assert result.imported_dsl_version == "" # Verify no app was created in database - apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() assert apps_count == 1 # Only the original test app def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): @@ -191,7 +191,7 @@ class TestAppDslService: assert result.imported_dsl_version == "" # Verify no app was created in database - apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() assert apps_count == 1 # Only the original test app def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): @@ -215,7 +215,7 @@ class TestAppDslService: ) # Verify no app was created in database - apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() assert apps_count == 1 # Only the original test app def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index dd2bc2181..5099362e0 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -57,7 +57,7 @@ class TestClearFreePlanTenantExpiredLogs: def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): """Test when no related records are found.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.filter.return_value.all.return_value = [] + mock_session.query.return_value.where.return_value.all.return_value = [] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) @@ -70,7 +70,7 @@ class TestClearFreePlanTenantExpiredLogs: ): """Test when records are found and have to_dict method.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.filter.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) @@ -101,7 +101,7 @@ class TestClearFreePlanTenantExpiredLogs: records.append(record) # Mock records for first table only, empty for others - mock_session.query.return_value.filter.return_value.all.side_effect = [ + mock_session.query.return_value.where.return_value.all.side_effect = [ records, [], [], @@ -123,13 +123,13 @@ class TestClearFreePlanTenantExpiredLogs: with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: mock_storage.save.side_effect = Exception("Storage error") - mock_session.query.return_value.filter.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = sample_records # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if backup fails - assert mock_session.query.return_value.filter.return_value.delete.called + assert mock_session.query.return_value.where.return_value.delete.called def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): """Test that method continues even when record serialization fails.""" @@ -138,30 +138,30 @@ class TestClearFreePlanTenantExpiredLogs: record.id = "record-1" record.to_dict.side_effect = Exception("Serialization error") - mock_session.query.return_value.filter.return_value.all.return_value = [record] + mock_session.query.return_value.where.return_value.all.return_value = [record] # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if serialization fails - assert mock_session.query.return_value.filter.return_value.delete.called + assert mock_session.query.return_value.where.return_value.delete.called def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): """Test that deletion is called for found records.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.filter.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should call delete for each table that has records - assert mock_session.query.return_value.filter.return_value.delete.called + assert mock_session.query.return_value.where.return_value.delete.called def test_clear_message_related_tables_logging_output( self, mock_session, sample_message_ids, sample_records, capsys ): """Test that logging output is generated.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.filter.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)