From 70da81d0e5133bcf9d867385c5927b9f61432619 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 19 Aug 2025 14:41:52 +0900 Subject: [PATCH] try ast-grep (#24149) --- .github/workflows/autofix.yml | 3 ++ api/controllers/console/app/generator.py | 2 +- .../console/datasets/upload_file.py | 2 +- .../task_pipeline/message_cycle_manager.py | 2 +- api/core/llm_generator/llm_generator.py | 6 ++-- .../clean_workflow_runlogs_precise.py | 34 +++++++++---------- api/services/annotation_service.py | 8 ++--- .../services/test_annotation_service.py | 4 +-- .../test_api_based_extension_service.py | 2 +- .../services/test_message_service.py | 2 +- .../test_model_load_balancing_service.py | 2 +- 11 files changed, 35 insertions(+), 32 deletions(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 152ff3b64..f5ba498c7 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -23,6 +23,9 @@ jobs: uv run ruff check --fix-only . # Format code uv run ruff format . + - name: ast-grep + run: | + uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 2a81d1b64..1cabfd9f2 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -137,7 +137,7 @@ class InstructionGenerateApi(Resource): from models import App, db from services.workflow_service import WorkflowService - app = db.session.query(App).filter(App.id == args["flow_id"]).first() + app = db.session.query(App).where(App.id == args["flow_id"]).first() if not app: return {"error": f"app {args['flow_id']} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py index 9b456c771..2afdaf7f2 100644 --- a/api/controllers/console/datasets/upload_file.py +++ b/api/controllers/console/datasets/upload_file.py @@ -39,7 +39,7 @@ class UploadFileApi(Resource): data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("UploadFile not found.") else: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index f3b9dbf75..0d786ba05 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -181,7 +181,7 @@ class MessageCycleManager: :param message_id: message id :return: """ - message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() + message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first() event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE return MessageStreamResponse( diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 64fc3a3e8..503f8e3e8 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -399,9 +399,9 @@ class LLMGenerator: def instruction_modify_legacy( tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None ) -> dict: - app: App | None = db.session.query(App).filter(App.id == flow_id).first() + app: App | None = db.session.query(App).where(App.id == flow_id).first() last_run: Message | None = ( - db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() ) if not last_run: return LLMGenerator.__instruction_modify_common( @@ -442,7 +442,7 @@ class LLMGenerator: ) -> dict: from services.workflow_service import WorkflowService - app: App | None = db.session.query(App).filter(App.id == flow_id).first() + app: App | None = db.session.query(App).where(App.id == flow_id).first() if not app: raise ValueError("App not found.") workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py index 0de3b5f68..8c21be01d 100644 --- a/api/schedule/clean_workflow_runlogs_precise.py +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -37,7 +37,7 @@ def clean_workflow_runlogs_precise(): cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) try: - total_workflow_runs = db.session.query(WorkflowRun).filter(WorkflowRun.created_at < cutoff_date).count() + total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count() if total_workflow_runs == 0: _logger.info("No expired workflow run logs found") return @@ -49,7 +49,7 @@ def clean_workflow_runlogs_precise(): while True: workflow_runs = ( - db.session.query(WorkflowRun.id).filter(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() + db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() ) if not workflow_runs: @@ -99,52 +99,52 @@ def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> message_id_list = [msg.id for msg in message_data] conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) if message_id_list: - db.session.query(AppAnnotationHitHistory).filter( + db.session.query(AppAnnotationHitHistory).where( AppAnnotationHitHistory.message_id.in_(message_id_list) ).delete(synchronize_session=False) - db.session.query(MessageAgentThought).filter( - MessageAgentThought.message_id.in_(message_id_list) - ).delete(synchronize_session=False) - - db.session.query(MessageChain).filter(MessageChain.message_id.in_(message_id_list)).delete( + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete( synchronize_session=False ) - db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_id_list)).delete( + db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete( synchronize_session=False ) - db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id.in_(message_id_list)).delete( + db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete( synchronize_session=False ) - db.session.query(MessageFeedback).filter(MessageFeedback.message_id.in_(message_id_list)).delete( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete( synchronize_session=False ) - db.session.query(Message).filter(Message.workflow_run_id.in_(workflow_run_ids)).delete( + db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete( synchronize_session=False ) - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( + db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( synchronize_session=False ) - db.session.query(WorkflowNodeExecutionModel).filter( + db.session.query(WorkflowNodeExecutionModel).where( WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) ).delete(synchronize_session=False) if conversation_id_list: - db.session.query(ConversationVariable).filter( + db.session.query(ConversationVariable).where( ConversationVariable.conversation_id.in_(conversation_id_list) ).delete(synchronize_session=False) - db.session.query(Conversation).filter(Conversation.id.in_(conversation_id_list)).delete( + db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( synchronize_session=False ) - db.session.query(WorkflowRun).filter(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) + db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) db.session.commit() return True diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index b7a047914..1a0fdfa42 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -293,7 +293,7 @@ class AppAnnotationService: annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] # Step 2: Bulk delete hit histories in a single query - db.session.query(AppAnnotationHitHistory).filter( + db.session.query(AppAnnotationHitHistory).where( AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) ).delete(synchronize_session=False) @@ -307,7 +307,7 @@ class AppAnnotationService: # Step 4: Bulk delete annotations in a single query deleted_count = ( db.session.query(MessageAnnotation) - .filter(MessageAnnotation.id.in_(annotation_ids_to_delete)) + .where(MessageAnnotation.id.in_(annotation_ids_to_delete)) .delete(synchronize_session=False) ) @@ -505,9 +505,9 @@ class AppAnnotationService: db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) - annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id) + annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id) for annotation in annotations_query.yield_per(100): - annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter( + annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where( AppAnnotationHitHistory.annotation_id == annotation.id ) for annotation_hit_history in annotation_hit_histories_query.yield_per(100): diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 0ab5f398e..8816698af 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -471,7 +471,7 @@ class TestAnnotationService: # Verify annotation was deleted from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) @@ -1175,7 +1175,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() assert deleted_annotation is None # Verify delete_annotation_index_task was called diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 38f532fd6..6cd8337ff 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -234,7 +234,7 @@ class TestAPIBasedExtensionService: # Verify extension was deleted from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first() + deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() assert deleted_extension is None def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 25ba0d03e..ece6de6cd 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -484,7 +484,7 @@ class TestMessageService: # Verify feedback was deleted from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first() + deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index a8a36b256..cb20238f0 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -469,6 +469,6 @@ class TestModelLoadBalancingService: # Verify inherit config was created in database inherit_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() + db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() ) assert len(inherit_configs) == 1