From e600070a61f350d19f87619a4314b47e75d49b14 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 13 Aug 2025 11:13:08 +0800 Subject: [PATCH] feat(api): auto-delete WorkflowDraftVariable when app is deleted (#23737) This commit introduces a background task that automatically deletes `WorkflowDraftVariable` records when their associated workflow apps are deleted. Additionally, it adds a new cleanup script `cleanup-orphaned-draft-variables` to remove existing orphaned draft variables from the database. --- api/commands.py | 136 ++++++++++ api/extensions/ext_commands.py | 2 + api/tasks/remove_app_and_related_data_task.py | 74 +++++- api/tests/integration_tests/tasks/__init__.py | 0 .../test_remove_app_and_related_data_task.py | 214 +++++++++++++++ api/tests/unit_tests/tasks/__init__.py | 0 .../test_remove_app_and_related_data_task.py | 243 ++++++++++++++++++ 7 files changed, 665 insertions(+), 4 deletions(-) create mode 100644 api/tests/integration_tests/tasks/__init__.py create mode 100644 api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py create mode 100644 api/tests/unit_tests/tasks/__init__.py create mode 100644 api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py diff --git a/api/commands.py b/api/commands.py index 8ee52ba71..6b38e34b9 100644 --- a/api/commands.py +++ b/api/commands.py @@ -36,6 +36,7 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration +from tasks.remove_app_and_related_data_task import delete_draft_variables_batch @click.command("reset-password", help="Reset the account password.") @@ -1202,3 +1203,138 @@ def setup_system_tool_oauth_client(provider, client_params): db.session.add(oauth_client) db.session.commit() click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + +def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: + """ + Find draft variables that reference non-existent apps. + + Args: + batch_size: Maximum number of orphaned app IDs to return + + Returns: + List of app IDs that have draft variables but don't exist in the apps table + """ + query = """ + SELECT DISTINCT wdv.app_id + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + LIMIT :batch_size + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query), {"batch_size": batch_size}) + return [row[0] for row in result] + + +def _count_orphaned_draft_variables() -> dict[str, Any]: + """ + Count orphaned draft variables by app. + + Returns: + Dictionary with statistics about orphaned variables + """ + query = """ + SELECT + wdv.app_id, + COUNT(*) as variable_count + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + GROUP BY wdv.app_id + ORDER BY variable_count DESC + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query)) + orphaned_by_app = {row[0]: row[1] for row in result} + + total_orphaned = sum(orphaned_by_app.values()) + app_count = len(orphaned_by_app) + + return { + "total_orphaned_variables": total_orphaned, + "orphaned_app_count": app_count, + "orphaned_by_app": orphaned_by_app, + } + + +@click.command() +@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") +@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)") +@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +def cleanup_orphaned_draft_variables( + dry_run: bool, + batch_size: int, + max_apps: int | None, + force: bool = False, +): + """ + Clean up orphaned draft variables from the database. + + This script finds and removes draft variables that belong to apps + that no longer exist in the database. + """ + logger = logging.getLogger(__name__) + + # Get statistics + stats = _count_orphaned_draft_variables() + + logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) + logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) + + if stats["total_orphaned_variables"] == 0: + logger.info("No orphaned draft variables found. Exiting.") + return + + if dry_run: + logger.info("DRY RUN: Would delete the following:") + for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[ + :10 + ]: # Show top 10 + logger.info(" App %s: %s variables", app_id, count) + if len(stats["orphaned_by_app"]) > 10: + logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) + return + + # Confirm deletion + if not force: + click.confirm( + f"Are you sure you want to delete {stats['total_orphaned_variables']} " + f"orphaned draft variables from {stats['orphaned_app_count']} apps?", + abort=True, + ) + + total_deleted = 0 + processed_apps = 0 + + while True: + if max_apps and processed_apps >= max_apps: + logger.info("Reached maximum app limit (%s). Stopping.", max_apps) + break + + orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) + if not orphaned_app_ids: + logger.info("No more orphaned draft variables found.") + break + + for app_id in orphaned_app_ids: + if max_apps and processed_apps >= max_apps: + break + + try: + deleted_count = delete_draft_variables_batch(app_id, batch_size) + total_deleted += deleted_count + processed_apps += 1 + + logger.info("Deleted %s variables for app %s", deleted_count, app_id) + + except Exception: + logger.exception("Error processing app %s", app_id) + continue + + logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 600e336c1..8904ff7a9 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, convert_to_agent_apps, @@ -42,6 +43,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, remove_orphaned_files_on_storage, setup_system_tool_oauth_client, + cleanup_orphaned_draft_variables, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 929b60e52..828c52044 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -33,7 +33,11 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from models.workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, +) from repositories.factory import DifyAPIRepositoryFactory @@ -62,6 +66,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_end_users(tenant_id, app_id) _delete_trace_app_configs(tenant_id, app_id) _delete_conversation_variables(app_id=app_id) + _delete_draft_variables(app_id) end_at = time.perf_counter() logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) @@ -91,7 +96,12 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) - _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") + _delete_records( + """select id from sites where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_site, + "site", + ) def _delete_app_mcp_servers(tenant_id: str, app_id: str): @@ -111,7 +121,10 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_api_token, + "api token", ) @@ -273,7 +286,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): db.session.query(Message).where(Message.id == message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" + """select id from messages where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_message, + "message", ) @@ -329,6 +345,56 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str): ) +def _delete_draft_variables(app_id: str): + """Delete all workflow draft variables for an app in batches.""" + return delete_draft_variables_batch(app_id, batch_size=1000) + + +def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: + """ + Delete draft variables for an app in batches. + + Args: + app_id: The ID of the app whose draft variables should be deleted + batch_size: Number of records to delete per batch + + Returns: + Total number of records deleted + """ + if batch_size <= 0: + raise ValueError("batch_size must be positive") + + total_deleted = 0 + + while True: + with db.engine.begin() as conn: + # Get a batch of draft variable IDs + query_sql = """ + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """ + result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) + + draft_var_ids = [row[0] for row in result] + if not draft_var_ids: + break + + # Delete the batch + delete_sql = """ + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """ + deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}) + batch_deleted = deleted_result.rowcount + total_deleted += batch_deleted + + logging.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green")) + + logging.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green")) + return total_deleted + + def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: diff --git a/api/tests/integration_tests/tasks/__init__.py b/api/tests/integration_tests/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py new file mode 100644 index 000000000..2f7fc60ad --- /dev/null +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -0,0 +1,214 @@ +import uuid + +import pytest +from sqlalchemy import delete + +from core.variables.segments import StringSegment +from models import Tenant, db +from models.model import App +from models.workflow import WorkflowDraftVariable +from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch + + +@pytest.fixture +def app_and_tenant(flask_req_ctx): + tenant_id = uuid.uuid4() + tenant = Tenant( + id=tenant_id, + name="test_tenant", + ) + db.session.add(tenant) + + app = App( + tenant_id=tenant_id, # Now tenant.id will have a value + name=f"Test App for tenant {tenant.id}", + mode="workflow", + enable_site=True, + enable_api=True, + ) + db.session.add(app) + db.session.flush() + yield (tenant, app) + + # Cleanup with proper error handling + db.session.delete(app) + db.session.delete(tenant) + + +class TestDeleteDraftVariablesIntegration: + @pytest.fixture + def setup_test_data(self, app_and_tenant): + """Create test data with apps and draft variables.""" + tenant, app = app_and_tenant + + # Create a second app for testing + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, + ) + db.session.add(app2) + db.session.commit() + + # Create draft variables for both apps + variables_app1 = [] + variables_app2 = [] + + for i in range(5): + var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var1) + variables_app1.append(var1) + + var2 = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var2) + variables_app2.append(var2) + + # Commit all the variables to the database + db.session.commit() + + yield { + "app1": app, + "app2": app2, + "tenant": tenant, + "variables_app1": variables_app1, + "variables_app2": variables_app2, + } + + # Cleanup - refresh session and check if objects still exist + db.session.rollback() # Clear any pending changes + + # Clean up remaining variables + cleanup_query = ( + delete(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id.in_([app.id, app2.id]), + ) + .execution_options(synchronize_session=False) + ) + db.session.execute(cleanup_query) + + # Clean up app2 + app2_obj = db.session.get(App, app2.id) + if app2_obj: + db.session.delete(app2_obj) + + db.session.commit() + + def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data): + """Test that batch deletion only removes variables for the specified app.""" + data = setup_test_data + app1_id = data["app1"].id + app2_id = data["app2"].id + + # Verify initial state + app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + assert app1_vars_before == 5 + assert app2_vars_before == 5 + + # Delete app1 variables + deleted_count = delete_draft_variables_batch(app1_id, batch_size=10) + + # Verify results + assert deleted_count == 5 + + app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + + assert app1_vars_after == 0 # All app1 variables deleted + assert app2_vars_after == 5 # App2 variables unchanged + + def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data): + """Test batch deletion with small batch size processes all records.""" + data = setup_test_data + app1_id = data["app1"].id + + # Use small batch size to force multiple batches + deleted_count = delete_draft_variables_batch(app1_id, batch_size=2) + + assert deleted_count == 5 + + # Verify all variables are deleted + remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert remaining_vars == 0 + + def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): + """Test that deleting variables for nonexistent app returns 0.""" + nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format + + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100) + + assert deleted_count == 0 + + def test_delete_draft_variables_wrapper_function(self, setup_test_data): + """Test that _delete_draft_variables wrapper function works correctly.""" + data = setup_test_data + app1_id = data["app1"].id + + # Verify initial state + vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert vars_before == 5 + + # Call wrapper function + deleted_count = _delete_draft_variables(app1_id) + + # Verify results + assert deleted_count == 5 + + vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert vars_after == 0 + + def test_batch_deletion_handles_large_dataset(self, app_and_tenant): + """Test batch deletion with larger dataset to verify batching logic.""" + tenant, app = app_and_tenant + + # Create many draft variables + variables = [] + for i in range(25): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var) + variables.append(var) + variable_ids = [i.id for i in variables] + + # Commit the variables to the database + db.session.commit() + + try: + # Use small batch size to force multiple batches + deleted_count = delete_draft_variables_batch(app.id, batch_size=8) + + assert deleted_count == 25 + + # Verify all variables are deleted + remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + assert remaining_vars == 0 + + finally: + query = ( + delete(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.id.in_(variable_ids), + ) + .execution_options(synchronize_session=False) + ) + db.session.execute(query) diff --git a/api/tests/unit_tests/tasks/__init__.py b/api/tests/unit_tests/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py new file mode 100644 index 000000000..d8003570b --- /dev/null +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -0,0 +1,243 @@ +from unittest.mock import ANY, MagicMock, call, patch + +import pytest +import sqlalchemy as sa + +from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch + + +class TestDeleteDraftVariablesBatch: + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_draft_variables_batch_success(self, mock_db): + """Test successful deletion of draft variables in batches.""" + app_id = "test-app-id" + batch_size = 100 + + # Mock database connection and engine + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock two batches of results, then empty + batch1_ids = [f"var-{i}" for i in range(100)] + batch2_ids = [f"var-{i}" for i in range(100, 150)] + + # Setup side effects for execute calls in the correct order: + # 1. SELECT (returns batch1_ids) + # 2. DELETE (returns result with rowcount=100) + # 3. SELECT (returns batch2_ids) + # 4. DELETE (returns result with rowcount=50) + # 5. SELECT (returns empty, ends loop) + + # Create mock results with actual integer rowcount attributes + class MockResult: + def __init__(self, rowcount): + self.rowcount = rowcount + + # First SELECT result + select_result1 = MagicMock() + select_result1.__iter__.return_value = iter([(id_,) for id_ in batch1_ids]) + + # First DELETE result + delete_result1 = MockResult(rowcount=100) + + # Second SELECT result + select_result2 = MagicMock() + select_result2.__iter__.return_value = iter([(id_,) for id_ in batch2_ids]) + + # Second DELETE result + delete_result2 = MockResult(rowcount=50) + + # Third SELECT result (empty, ends loop) + select_result3 = MagicMock() + select_result3.__iter__.return_value = iter([]) + + # Configure side effects in the correct order + mock_conn.execute.side_effect = [ + select_result1, # First SELECT + delete_result1, # First DELETE + select_result2, # Second SELECT + delete_result2, # Second DELETE + select_result3, # Third SELECT (empty) + ] + + # Execute the function + result = delete_draft_variables_batch(app_id, batch_size) + + # Verify the result + assert result == 150 + + # Verify database calls + assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes + + # Verify the expected calls in order: + # 1. SELECT, 2. DELETE, 3. SELECT, 4. DELETE, 5. SELECT + expected_calls = [ + # First SELECT + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + # First DELETE + call( + sa.text(""" + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """), + {"ids": tuple(batch1_ids)}, + ), + # Second SELECT + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + # Second DELETE + call( + sa.text(""" + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """), + {"ids": tuple(batch2_ids)}, + ), + # Third SELECT (empty result) + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + ] + + # Check that all calls were made correctly + actual_calls = mock_conn.execute.call_args_list + assert len(actual_calls) == len(expected_calls) + + # Simplified verification - just check that the right number of calls were made + # and that the SQL queries contain the expected patterns + for i, actual_call in enumerate(actual_calls): + if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4) + # Verify it's a SELECT query + sql_text = str(actual_call[0][0]) + assert "SELECT id FROM workflow_draft_variables" in sql_text + assert "WHERE app_id = :app_id" in sql_text + assert "LIMIT :batch_size" in sql_text + else: # DELETE calls (odd indices: 1, 3) + # Verify it's a DELETE query + sql_text = str(actual_call[0][0]) + assert "DELETE FROM workflow_draft_variables" in sql_text + assert "WHERE id IN :ids" in sql_text + + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_draft_variables_batch_empty_result(self, mock_db): + """Test deletion when no draft variables exist for the app.""" + app_id = "nonexistent-app-id" + batch_size = 1000 + + # Mock database connection + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock empty result + empty_result = MagicMock() + empty_result.__iter__.return_value = iter([]) + mock_conn.execute.return_value = empty_result + + result = delete_draft_variables_batch(app_id, batch_size) + + assert result == 0 + assert mock_conn.execute.call_count == 1 # Only one select query + + def test_delete_draft_variables_batch_invalid_batch_size(self): + """Test that invalid batch size raises ValueError.""" + app_id = "test-app-id" + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, -1) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, 0) + + @patch("tasks.remove_app_and_related_data_task.db") + @patch("tasks.remove_app_and_related_data_task.logging") + def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db): + """Test that batch deletion logs progress correctly.""" + app_id = "test-app-id" + batch_size = 50 + + # Mock database + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock one batch then empty + batch_ids = [f"var-{i}" for i in range(30)] + # Create properly configured mocks + select_result = MagicMock() + select_result.__iter__.return_value = iter([(id_,) for id_ in batch_ids]) + + # Create simple object with rowcount attribute + class MockResult: + def __init__(self, rowcount): + self.rowcount = rowcount + + delete_result = MockResult(rowcount=30) + + empty_result = MagicMock() + empty_result.__iter__.return_value = iter([]) + + mock_conn.execute.side_effect = [ + # Select query result + select_result, + # Delete query result + delete_result, + # Empty select result (end condition) + empty_result, + ] + + result = delete_draft_variables_batch(app_id, batch_size) + + assert result == 30 + + # Verify logging calls + assert mock_logging.info.call_count == 2 + mock_logging.info.assert_any_call( + ANY # click.style call + ) + + @patch("tasks.remove_app_and_related_data_task.delete_draft_variables_batch") + def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete): + """Test that _delete_draft_variables calls the batch function correctly.""" + app_id = "test-app-id" + expected_return = 42 + mock_batch_delete.return_value = expected_return + + result = _delete_draft_variables(app_id) + + assert result == expected_return + mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)