@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
    """
    Back up and delete rows of message-related tables for a batch of messages.

    For every related model, the rows whose ``message_id`` is in
    ``batch_message_ids`` are loaded, dumped as JSON to storage under
    ``free_plan_tenant_expired_logs/{tenant_id}/{table_name}/`` and then
    deleted by primary key.

    The backup is best-effort: a record that cannot be serialized is logged
    and left out of the backup file, and a failing ``storage.save`` is logged;
    in both cases the rows are still deleted. This method never commits; the
    caller owns the transaction on ``session``.

    Args:
        session: Database session, the same one used by ``process_tenant``.
        tenant_id: Tenant ID, used for the backup path and the progress output.
        batch_message_ids: IDs of the messages whose related rows are removed.
            An empty list is a no-op.
    """
    if not batch_message_ids:
        return

    # (model, table name used in the backup path and in log output)
    related_tables = [
        (MessageFeedback, "message_feedbacks"),
        (MessageFile, "message_files"),
        (MessageAnnotation, "message_annotations"),
        (MessageChain, "message_chains"),
        (MessageAgentThought, "message_agent_thoughts"),
        (AppAnnotationHitHistory, "app_annotation_hit_histories"),
        (SavedMessage, "saved_messages"),
    ]

    for model, table_name in related_tables:
        # Query records related to expired messages
        records = (
            session.query(model)
            .filter(
                model.message_id.in_(batch_message_ids),  # type: ignore
            )
            .all()
        )

        if not records:
            continue

        # Collect the ids up front so deletion does not depend on the backup succeeding.
        record_ids = [record.id for record in records]

        # Save records before deletion
        try:
            record_data = []
            for record in records:
                try:
                    record_data.append(cls._serialize_record(record))
                except Exception:
                    # One bad record must not prevent the rest from being backed up.
                    logger.exception("Failed to transform %s record: %s", table_name, record.id)

            if record_data:
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(record_data),
                    ).encode("utf-8"),
                )
        except Exception:
            logger.exception("Failed to save %s records", table_name)

        session.query(model).filter(
            model.id.in_(record_ids),  # type: ignore
        ).delete(synchronize_session=False)

        click.echo(
            click.style(
                f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
                f"{table_name} records for tenant {tenant_id}"
            )
        )


@staticmethod
def _serialize_record(record) -> dict:
    """Turn one ORM record into a plain dict suitable for the JSON backup.

    Uses the record's own ``to_dict`` when it has one; otherwise copies every
    column listed in ``record.__table__.columns`` by column name.
    """
    if hasattr(record, "to_dict"):
        return record.to_dict()
    return {column.name: getattr(record, column.name) for column in record.__table__.columns}
class TestClearFreePlanTenantExpiredLogs:
    """Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""

    # Number of (model, table_name) pairs the method under test iterates over.
    RELATED_TABLE_COUNT = 7

    @pytest.fixture
    def mock_session(self):
        """Create a mock database session whose queries find nothing by default."""
        session = Mock(spec=Session)
        session.query.return_value.filter.return_value.all.return_value = []
        session.query.return_value.filter.return_value.delete.return_value = 0
        return session

    @pytest.fixture
    def mock_storage(self):
        """Create a mock storage object.

        The tests below patch the service module's ``storage`` directly, so this
        fixture is only kept for tests that want a standalone storage double.
        """
        storage = Mock()
        storage.save.return_value = None
        return storage

    @pytest.fixture
    def sample_message_ids(self):
        """Sample message IDs for testing."""
        return ["msg-1", "msg-2", "msg-3"]

    @pytest.fixture
    def sample_records(self):
        """Three mock records that expose a working ``to_dict``."""
        records = []
        for i in range(3):
            record = Mock()
            record.id = f"record-{i}"
            record.to_dict.return_value = {
                "id": f"record-{i}",
                "message_id": f"msg-{i}",
                "created_at": datetime.datetime.now().isoformat(),
            }
            records.append(record)
        return records

    def test_clear_message_related_tables_empty_message_ids(self, mock_session):
        """Test that method returns early when message_ids is empty."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])

            # Should not call any database operations
            mock_session.query.assert_not_called()
            mock_storage.save.assert_not_called()

    def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
        """Test when no related records are found."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.filter.return_value.all.return_value = []

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # One lookup per related table, and nothing backed up or deleted
            assert mock_session.query.call_count == self.RELATED_TABLE_COUNT
            mock_storage.save.assert_not_called()
            mock_session.query.return_value.filter.return_value.delete.assert_not_called()

    def test_clear_message_related_tables_with_records_and_to_dict(
        self, mock_session, sample_message_ids, sample_records
    ):
        """Test when records are found and have to_dict method."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_session.query.return_value.filter.return_value.all.return_value = sample_records

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # The same records are returned for every table, so to_dict runs once per table
            for record in sample_records:
                assert record.to_dict.call_count == self.RELATED_TABLE_COUNT

            # One backup file per table
            assert mock_storage.save.call_count == self.RELATED_TABLE_COUNT

    def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
        """Test when records are found but don't have to_dict method."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            # Create records without to_dict method
            records = []
            for i in range(2):
                record = Mock()
                mock_table = Mock()
                mock_id_column = Mock()
                mock_id_column.name = "id"
                mock_message_id_column = Mock()
                mock_message_id_column.name = "message_id"
                mock_table.columns = [mock_id_column, mock_message_id_column]
                record.__table__ = mock_table
                record.id = f"record-{i}"
                record.message_id = f"msg-{i}"
                # Deleting the attribute makes hasattr(record, "to_dict") False on a Mock
                del record.to_dict
                records.append(record)

            # Mock records for first table only, empty for others
            mock_session.query.return_value.filter.return_value.all.side_effect = [
                records,
                [],
                [],
                [],
                [],
                [],
                [],
            ]

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Exactly one backup, for the first related table, even without to_dict
            assert mock_storage.save.call_count == 1
            saved_path = mock_storage.save.call_args.args[0]
            assert saved_path.startswith("free_plan_tenant_expired_logs/tenant-123/message_feedbacks/")

    def test_clear_message_related_tables_storage_error_continues(
        self, mock_session, sample_message_ids, sample_records
    ):
        """Test that method continues even when storage.save fails."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            mock_storage.save.side_effect = Exception("Storage error")

            mock_session.query.return_value.filter.return_value.all.return_value = sample_records

            # Should not raise exception
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should still delete records for every table even if backup fails
            delete_mock = mock_session.query.return_value.filter.return_value.delete
            assert delete_mock.call_count == self.RELATED_TABLE_COUNT

    def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
        """Test that method continues even when record serialization fails."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
            record = Mock()
            record.id = "record-1"
            record.to_dict.side_effect = Exception("Serialization error")

            mock_session.query.return_value.filter.return_value.all.return_value = [record]

            # Should not raise exception
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Nothing could be serialized, so no backup file is written
            mock_storage.save.assert_not_called()

            # Should still delete records even if serialization fails
            assert mock_session.query.return_value.filter.return_value.delete.called

    def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
        """Test that deletion is called for found records."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage"):
            mock_session.query.return_value.filter.return_value.all.return_value = sample_records

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            # Should call delete for each table that has records
            delete_mock = mock_session.query.return_value.filter.return_value.delete
            assert delete_mock.call_count == self.RELATED_TABLE_COUNT
            delete_mock.assert_called_with(synchronize_session=False)

    def test_clear_message_related_tables_logging_output(
        self, mock_session, sample_message_ids, sample_records, capsys
    ):
        """Test that a progress line is echoed for every processed table."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage"):
            mock_session.query.return_value.filter.return_value.all.return_value = sample_records

            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)

            output = capsys.readouterr().out
            assert "Processed 3 message_feedbacks records for tenant tenant-123" in output
            assert output.count("records for tenant tenant-123") == self.RELATED_TABLE_COUNT