import datetime import json import logging import time from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import ( App, AppAnnotationHitHistory, Conversation, Message, MessageAgentThought, MessageAnnotation, MessageChain, MessageFeedback, MessageFile, ) from models.web import SavedMessage from models.workflow import WorkflowAppLog from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) class ClearFreePlanTenantExpiredLogs: @classmethod def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None: """ Clean up message-related tables to avoid data redundancy. This method cleans up tables that have foreign key relationships with Message. Args: session: Database session, the same with the one in process_tenant method tenant_id: Tenant ID for logging purposes batch_message_ids: List of message IDs to clean up """ if not batch_message_ids: return # Clean up each related table related_tables = [ (MessageFeedback, "message_feedbacks"), (MessageFile, "message_files"), (MessageAnnotation, "message_annotations"), (MessageChain, "message_chains"), (MessageAgentThought, "message_agent_thoughts"), (AppAnnotationHitHistory, "app_annotation_hit_histories"), (SavedMessage, "saved_messages"), ] for model, table_name in related_tables: # Query records related to expired messages records = ( session.query(model) .where( model.message_id.in_(batch_message_ids), # type: ignore ) .all() ) if len(records) == 0: continue # Save records before deletion record_ids = [record.id for record in records] try: record_data = [] for record in records: try: if hasattr(record, "to_dict"): record_data.append(record.to_dict()) else: # if record doesn't have to_dict method, we need to transform it to dict manually record_dict = {} for column in record.__table__.columns: record_dict[column.name] = getattr(record, column.name) record_data.append(record_dict) except Exception: logger.exception("Failed to transform %s record: %s", table_name, record.id) continue if record_data: storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder(record_data), ).encode("utf-8"), ) except Exception: logger.exception("Failed to save %s records", table_name) session.query(model).where( model.id.in_(record_ids), # type: ignore ).delete(synchronize_session=False) click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(record_ids)} " f"{table_name} records for tenant {tenant_id}" ) ) @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): apps = db.session.query(App).where(App.tenant_id == tenant_id).all() app_ids = [app.id for app in apps] while True: with Session(db.engine).no_autoflush as session: messages = ( session.query(Message) .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) .all() ) if len(messages) == 0: break storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder( [message.to_dict() for message in messages], ), ).encode("utf-8"), ) message_ids = [message.id for message in messages] # delete messages session.query(Message).where( Message.id.in_(message_ids), ).delete(synchronize_session=False) cls._clear_message_related_tables(session, tenant_id, message_ids) session.commit() click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id} " ) ) while True: with Session(db.engine).no_autoflush as session: conversations = ( session.query(Conversation) .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) .all() ) if len(conversations) == 0: break storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder( [conversation.to_dict() for conversation in conversations], ), ).encode("utf-8"), ) conversation_ids = [conversation.id for conversation in conversations] session.query(Conversation).where( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) session.commit() click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}" f" conversations for tenant {tenant_id}" ) ) # Process expired workflow node executions with backup session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) before_date = datetime.datetime.now() - datetime.timedelta(days=days) total_deleted = 0 while True: # Get a batch of expired executions for backup workflow_node_executions = node_execution_repo.get_expired_executions_batch( tenant_id=tenant_id, before_date=before_date, batch_size=batch, ) if len(workflow_node_executions) == 0: break # Save workflow node executions to storage storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder(workflow_node_executions), ).encode("utf-8"), ) # Extract IDs for deletion workflow_node_execution_ids = [ workflow_node_execution.id for workflow_node_execution in workflow_node_executions ] # Delete the backed up executions deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) total_deleted += deleted_count click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" f" workflow node executions for tenant {tenant_id}" ) ) # If we got fewer than the batch size, we're done if len(workflow_node_executions) < batch: break # Process expired workflow runs with backup session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) before_date = datetime.datetime.now() - datetime.timedelta(days=days) total_deleted = 0 while True: # Get a batch of expired workflow runs for backup workflow_runs = workflow_run_repo.get_expired_runs_batch( tenant_id=tenant_id, before_date=before_date, batch_size=batch, ) if len(workflow_runs) == 0: break # Save workflow runs to storage storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder( [workflow_run.to_dict() for workflow_run in workflow_runs], ), ).encode("utf-8"), ) # Extract IDs for deletion workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] # Delete the backed up workflow runs deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids) total_deleted += deleted_count click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}" f" workflow runs for tenant {tenant_id}" ) ) # If we got fewer than the batch size, we're done if len(workflow_runs) < batch: break while True: with Session(db.engine).no_autoflush as session: workflow_app_logs = ( session.query(WorkflowAppLog) .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) .all() ) if len(workflow_app_logs) == 0: break # save workflow app logs storage.save( f"free_plan_tenant_expired_logs/" f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}" f"-{time.time()}.json", json.dumps( jsonable_encoder( [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs], ), ).encode("utf-8"), ) workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( synchronize_session=False ) session.commit() click.echo( click.style( f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}" f" workflow app logs for tenant {tenant_id}" ) ) @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): """ Clear free plan tenant expired logs. """ click.echo(click.style("Clearing free plan tenant expired logs", fg="white")) ended_at = datetime.datetime.now() started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) current_time = started_at with Session(db.engine) as session: total_tenant_count = session.query(Tenant.id).count() click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) handled_tenant_count = 0 thread_pool = ThreadPoolExecutor(max_workers=10) def process_tenant(flask_app: Flask, tenant_id: str) -> None: try: if ( not dify_config.BILLING_ENABLED or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox" ): # only process sandbox tenant cls.process_tenant(flask_app, tenant_id, days, batch) except Exception: logger.exception("Failed to process tenant %s", tenant_id) finally: nonlocal handled_tenant_count handled_tenant_count += 1 if handled_tenant_count % 100 == 0: click.echo( click.style( f"[{datetime.datetime.now()}] " f"Processed {handled_tenant_count} tenants " f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), " f"{handled_tenant_count}/{total_tenant_count}", fg="green", ) ) futures = [] if tenant_ids: for tenant_id in tenant_ids: futures.append( thread_pool.submit( process_tenant, current_app._get_current_object(), # type: ignore[attr-defined] tenant_id, ) ) else: while current_time < ended_at: click.echo( click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white") ) # Initial interval of 1 day, will be dynamically adjusted based on tenant count interval = datetime.timedelta(days=1) # Process tenants in this batch with Session(db.engine) as session: # Calculate tenant count in next batch with current interval # Try different intervals until we find one with a reasonable tenant count test_intervals = [ datetime.timedelta(days=1), datetime.timedelta(hours=12), datetime.timedelta(hours=6), datetime.timedelta(hours=3), datetime.timedelta(hours=1), ] for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) .where(Tenant.created_at.between(current_time, current_time + test_interval)) .count() ) if tenant_count <= 100: interval = test_interval break else: # If all intervals have too many tenants, use minimum interval interval = datetime.timedelta(hours=1) # Adjust interval to target ~100 tenants per batch if tenant_count > 0: # Scale interval based on ratio to target count interval = min( datetime.timedelta(days=1), # Max 1 day max( datetime.timedelta(hours=1), # Min 1 hour interval * (100 / tenant_count), # Scale to target 100 ), ) batch_end = min(current_time + interval, ended_at) rs = ( session.query(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) tenants = [] for row in rs: tenant_id = str(row.id) try: tenants.append(tenant_id) except Exception: logger.exception("Failed to process tenant %s", tenant_id) continue futures.append( thread_pool.submit( process_tenant, current_app._get_current_object(), # type: ignore[attr-defined] tenant_id, ) ) current_time = batch_end # wait for all threads to finish for future in futures: future.result()