diff --git a/api/.env.example b/api/.env.example index a7ea6cf93..eab017a62 100644 --- a/api/.env.example +++ b/api/.env.example @@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800 # hybrid: Save new data to object storage, read from both object storage and RDBMS WORKFLOW_NODE_EXECUTION_STORAGE=rdbms +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 963fcbedf..f6a8b037c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -537,6 +537,33 @@ class WorkflowNodeExecutionConfig(BaseSettings): ) +class RepositoryConfig(BaseSettings): + """ + Configuration for repository implementations + """ + + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", + ) + + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_RUN_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowRun operations. 
Specify as a module path", + default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -903,6 +930,7 @@ class FeatureConfig( MultiModalTransferConfig, PositionConfig, RagEtlConfig, + RepositoryConfig, SecurityConfig, ToolConfig, UpdateConfig, diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 03b60610a..3322350e2 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.CHANNEL: - raise AppNotFoundError() if mode is not None: if isinstance(mode, list): diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index efb4acc5f..ac2ebf2b0 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ import logging from dateutil.parser import isoparse from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource): if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: raise NotWorkflowAppError() - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + # Use repository to get workflow run + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + workflow_run = workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=workflow_run_id, + ) return workflow_run diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0d304de97..28bf4a9a2 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -3,6 +3,8 @@ import logging import uuid from typing import Optional, Union, cast +from sqlalchemy import select + from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig @@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner): if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = ( - db.session.query(Message) - .filter( - Message.conversation_id == self.message.conversation_id, + messages = ( + ( + db.session.execute( + select(Message) + .where(Message.conversation_id == self.message.conversation_id) + .order_by(Message.created_at.desc()) + ) ) - 
.order_by(Message.created_at.desc()) + .scalars() .all() ) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 7877408ce..4b8f5ebe2 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, 
app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 40a1e272a..2f9632e97 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = 
sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2a85cd5e3..c6b326d8a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,7 +3,6 @@ import time from collections.abc import Generator from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -68,7 +67,6 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowRun, ) logger = logging.getLogger(__name__) @@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline: tts_publisher.publish(None) def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: - workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) - assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API @@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline: return workflow_app_log = WorkflowAppLog() - workflow_app_log.tenant_id = workflow_run.tenant_id - workflow_app_log.app_id = workflow_run.app_id - workflow_app_log.workflow_id = workflow_run.workflow_id - workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id + workflow_app_log.app_id = self._application_generate_entity.app_config.app_id + workflow_app_log.workflow_id = workflow_execution.workflow_id + workflow_app_log.workflow_run_id = workflow_execution.id_ workflow_app_log.created_from = created_from.value workflow_app_log.created_by_role = self._created_by_role workflow_app_log.created_by = self._user_id diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d..a9f0a92e5 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,6 +1,8 @@ from collections.abc import Sequence from typing import Optional +from sqlalchemy import select + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file import file_manager from core.model_manager import ModelInstance @@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile -from models.workflow import WorkflowRun +from models.workflow import Workflow, WorkflowRun class TokenBufferMemory: - def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: + def __init__( + self, + conversation: Conversation, + model_instance: ModelInstance, + ) -> None: self.conversation = conversation self.model_instance = model_instance @@ -36,20 +42,8 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = ( - db.session.query( - 
Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id, - Message.parent_message_id, - Message.answer_tokens, - ) - .filter( - Message.conversation_id == self.conversation.id, - ) - .order_by(Message.created_at.desc()) + stmt = ( + select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) ) if message_limit and message_limit > 0: @@ -57,7 +51,9 @@ class TokenBufferMemory: else: message_limit = 500 - messages = query.limit(message_limit).all() + stmt = stmt.limit(message_limit) + + messages = db.session.scalars(stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message @@ -74,18 +70,20 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_run = db.session.scalar( + select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id) + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) else: - if message.workflow_run_id: - workflow_run = ( - db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() - ) - - if workflow_run and workflow_run.workflow: - file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, is_vision=False - ) + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a3dbce0e5..4a7e66d27 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f94e5e49d..8a559c492 100644 --- 
a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8bedea20f..fcbbc70fc 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 3917348a9..445c6a874 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index f7aef76c8..4b883622a 100644 --- 
a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,10 +1,11 @@ -from typing import Any +from collections.abc import Sequence from constants import UUID_NIL +from models import Message -def extract_thread_messages(messages: list[Any]): - thread_messages = [] +def extract_thread_messages(messages: Sequence[Message]): + thread_messages: list[Message] = [] next_message = None for message in messages: diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py index f49466db6..de64c27a7 100644 --- a/api/core/prompt/utils/get_thread_messages_length.py +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import Message @@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int: Get the number of thread messages based on the parent message id. """ # Fetch all messages related to the conversation - query = ( - db.session.query( - Message.id, - Message.parent_message_id, - Message.answer, - ) - .filter( - Message.conversation_id == conversation_id, - ) - .order_by(Message.created_at.desc()) - ) + stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc()) - messages = query.all() + messages = db.session.scalars(stmt).all() # Extract thread messages thread_messages = extract_thread_messages(messages) diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 645231712..052ba1c2c 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "DifyCoreRepositoryFactory", + "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py new file mode 100644 index 000000000..4118aa61c --- /dev/null +++ b/api/core/repositories/factory.py @@ -0,0 +1,224 @@ +""" +Repository factory for dynamically creating repository instances based on configuration. + +This module provides a Django-like settings system for repository implementations, +allowing users to configure different repository backends through string paths. 
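+
+A minimal usage sketch (illustrative only; it assumes the default SQLAlchemy
+implementations and borrows the `db`/`user`/`app_id` names from callers elsewhere
+in this change):
+
+    from sqlalchemy.orm import sessionmaker
+
+    from core.repositories import DifyCoreRepositoryFactory
+    from models.enums import WorkflowRunTriggeredFrom
+
+    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+    repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+        session_factory=session_factory,
+        user=user,  # Account or EndUser
+        app_id=app_id,
+        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+    )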
+""" + +import importlib +import inspect +import logging +from typing import Protocol, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class RepositoryImportError(Exception): + """Raised when a repository implementation cannot be imported or instantiated.""" + + pass + + +class DifyCoreRepositoryFactory: + """ + Factory for creating repository instances based on configuration. + + This factory supports Django-like settings where repository implementations + are specified as module paths (e.g., 'module.submodule.ClassName'). + """ + + @staticmethod + def _import_class(class_path: str) -> type: + """ + Import a class from a module path string. + + Args: + class_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class + + Raises: + RepositoryImportError: If the class cannot be imported + """ + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + repo_class = getattr(module, class_name) + assert isinstance(repo_class, type) + return repo_class + except (ValueError, ImportError, AttributeError) as e: + raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e + + @staticmethod + def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore + """ + Validate that a class implements the expected repository interface. + + Args: + repository_class: The class to validate + expected_interface: The expected interface/protocol + + Raises: + RepositoryImportError: If the class doesn't implement the interface + """ + # Check if the class has all required methods from the protocol + required_methods = [ + method + for method in dir(expected_interface) + if not method.startswith("_") and callable(getattr(expected_interface, method, None)) + ] + + missing_methods = [] + for method_name in required_methods: + if not hasattr(repository_class, method_name): + missing_methods.append(method_name) + + if missing_methods: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' does not implement required methods " + f"{missing_methods} from interface '{expected_interface.__name__}'" + ) + + @staticmethod + def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: + """ + Validate that a repository class constructor accepts required parameters. + + Args: + repository_class: The class to validate + required_params: List of required parameter names + + Raises: + RepositoryImportError: If the constructor doesn't accept required parameters + """ + + try: + # MyPy may flag the line below with the following error: + # + # > Accessing "__init__" on an instance is unsound, since + # > instance.__init__ could be from an incompatible subclass. + # + # Despite this, we need to ensure that the constructor of `repository_class` + # has a compatible signature. 
+ signature = inspect.signature(repository_class.__init__) # type: ignore[misc] + param_names = list(signature.parameters.keys()) + + # Remove 'self' parameter + if "self" in param_names: + param_names.remove("self") + + missing_params = [param for param in required_params if param not in param_names] + if missing_params: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " + f"{missing_params}. Expected parameters: {required_params}" + ) + except Exception as e: + raise RepositoryImportError( + f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" + ) from e + + @classmethod + def create_workflow_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + ) -> WorkflowExecutionRepository: + """ + Create a WorkflowExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowExecutionRepository") + raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e + + @classmethod + def create_workflow_node_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + ) -> WorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance based on configuration. 
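+
+        A typical call mirrors the trace integrations updated in this change
+        (sketch; `service_account`, `session_factory`, and `app_id` come from the
+        caller's scope):
+
+            repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+                session_factory=session_factory,
+                user=service_account,
+                app_id=app_id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            )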
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/models/model.py b/api/models/model.py index b1007c4a7..7e9e91727 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -50,7 +50,6 @@ class AppMode(StrEnum): CHAT = "chat" ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" - CHANNEL = "channel" @classmethod def value_of(cls, value: str) -> "AppMode": @@ -934,7 +933,7 @@ class Message(Base): created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - workflow_run_id = db.Column(StringUUID) + workflow_run_id: Mapped[str] = db.Column(StringUUID) @property def inputs(self): diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py new file mode 100644 index 000000000..00a2d1f87 --- /dev/null +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -0,0 +1,197 @@ +""" +Service-layer repository protocol for WorkflowNodeExecutionModel operations. + +This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel +that abstracts database queries currently done directly in service classes. This repository is +specifically designed for service-layer needs and is separate from the core domain repository. + +The service repository handles operations that require access to database-specific fields like +tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models.workflow import WorkflowNodeExecutionModel + + +class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): + """ + Protocol for service-layer operations on WorkflowNodeExecutionModel. 
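+
+    Consumers should obtain an implementation through the factory rather than
+    constructing one directly (sketch, assuming the default configuration):
+
+        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+        repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+        last_run = repo.get_node_last_execution(
+            tenant_id=tenant_id, app_id=app_id, workflow_id=workflow_id, node_id=node_id
+        )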
+ + This repository provides database access patterns specifically needed by service classes, + handling queries that involve database-specific fields and multi-tenancy concerns. + + Key responsibilities: + - Manages database operations for workflow node executions + - Handles multi-tenant data isolation + - Provides batch processing capabilities + - Supports execution lifecycle management + + Implementation notes: + - Returns database models directly (WorkflowNodeExecutionModel) + - Handles tenant/app filtering automatically + - Provides service-specific query patterns + - Focuses on database operations without domain logic + - Supports cleanup and maintenance operations + """ + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method finds the latest execution of a specific node within a workflow, + ordered by creation time. Used primarily for debugging and inspection purposes. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + ... + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method retrieves all node executions that belong to a specific workflow run, + ordered by index in descending order for proper trace visualization. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + ... + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method retrieves a specific execution by its unique identifier. + Tenant filtering is optional for cases where the execution ID is globally unique. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + ... + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + This method is used for cleanup operations to remove expired executions + in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + ... + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. 
+
+        This method is used when removing an app and all its related data.
+        Executions are deleted in batches to avoid overwhelming the database.
+
+        Args:
+            tenant_id: The tenant identifier
+            app_id: The application identifier
+            batch_size: Maximum number of executions to delete in one batch
+
+        Returns:
+            The total number of executions deleted
+        """
+        ...
+
+    def get_expired_executions_batch(
+        self,
+        tenant_id: str,
+        before_date: datetime,
+        batch_size: int = 1000,
+    ) -> Sequence[WorkflowNodeExecutionModel]:
+        """
+        Get a batch of expired workflow node executions for backup purposes.
+
+        This method retrieves expired executions without deleting them,
+        allowing the caller to backup the data before deletion.
+
+        Args:
+            tenant_id: The tenant identifier
+            before_date: Get executions created before this date
+            batch_size: Maximum number of executions to retrieve
+
+        Returns:
+            A sequence of WorkflowNodeExecutionModel instances
+        """
+        ...
+
+    def delete_executions_by_ids(
+        self,
+        execution_ids: Sequence[str],
+    ) -> int:
+        """
+        Delete workflow node executions by their IDs.
+
+        This method deletes specific executions by their IDs,
+        typically used after backing up the data.
+
+        This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
+        data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
+        additional tenant validation should be implemented to prevent unauthorized access.
+
+        Args:
+            execution_ids: List of execution IDs to delete
+
+        Returns:
+            The number of executions deleted
+        """
+        ...
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
new file mode 100644
index 000000000..59e7baeb7
--- /dev/null
+++ b/api/repositories/api_workflow_run_repository.py
@@ -0,0 +1,181 @@
+"""
+API WorkflowRun Repository Protocol
+
+This module defines the protocol for service-layer WorkflowRun operations.
+The repository provides an abstraction layer for WorkflowRun database operations
+used by service classes, separating service-layer concerns from core domain logic.
+
+Key Features:
+- Paginated workflow run queries with filtering
+- Bulk deletion operations with OSS backup support
+- Multi-tenant data isolation
+- Expired record cleanup with data retention
+- Service-layer specific query patterns
+
+Usage:
+    This protocol should be used by service classes that need to perform
+    WorkflowRun database operations. It provides a clean interface that
+    hides implementation details and supports dependency injection.
+
+Example:
+    ```python
+    from repositories.factory import DifyAPIRepositoryFactory
+
+    session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+    repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+    # Get paginated workflow runs
+    runs = repo.get_paginated_workflow_runs(
+        tenant_id="tenant-123",
+        app_id="app-456",
+        triggered_from="debugging",
+        limit=20
+    )
+    ```
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, Protocol
+
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.workflow import WorkflowRun
+
+
+class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
+    """
+    Protocol for service-layer WorkflowRun repository operations.
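+
+    The concrete implementation is selected through the `API_WORKFLOW_RUN_REPOSITORY`
+    setting introduced in this change, e.g. in `.env`:
+
+        API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository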
+ + This protocol defines the interface for WorkflowRun database operations + that are specific to service-layer needs, including pagination, filtering, + and bulk operations with data backup support. + """ + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Retrieves workflow runs for a specific app and trigger source with + cursor-based pagination support. Used primarily for debugging and + workflow run listing in the UI. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + + Returns: + InfiniteScrollPagination object containing: + - data: List of WorkflowRun objects + - limit: Applied limit + - has_more: Boolean indicating if more records exist + + Raises: + ValueError: If last_id is provided but the corresponding record doesn't exist + """ + ... + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID. + + Retrieves a single workflow run with tenant and app isolation. + Used for workflow run detail views and execution tracking. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + run_id: Workflow run identifier + + Returns: + WorkflowRun object if found, None otherwise + """ + ... + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup. + + Retrieves workflow runs created before the specified date for + cleanup operations. Used by scheduled tasks to remove old data + while maintaining data retention policies. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + before_date: Only return runs created before this date + batch_size: Maximum number of records to return + + Returns: + Sequence of WorkflowRun objects to be processed for cleanup + """ + ... + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs. + + Performs bulk deletion of workflow runs by ID. This method should + be used after backing up the data to OSS storage for retention. + + Args: + run_ids: Sequence of workflow run IDs to delete + + Returns: + Number of records actually deleted + + Note: + This method performs hard deletion. Ensure data is backed up + to OSS storage before calling this method for compliance with + data retention policies. + """ + ... + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app. + + Performs bulk deletion of all workflow runs associated with an app. + Used during app cleanup operations. Processes records in batches + to avoid memory issues and long-running transactions. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + batch_size: Number of records to process in each batch + + Returns: + Total number of records deleted across all batches + + Note: + This method performs hard deletion without backup. 
Use with caution + and ensure proper data retention policies are followed. + """ + ... diff --git a/api/repositories/factory.py b/api/repositories/factory.py new file mode 100644 index 000000000..0a0adbf2c --- /dev/null +++ b/api/repositories/factory.py @@ -0,0 +1,103 @@ +""" +DifyAPI Repository Factory for creating repository instances. + +This factory is specifically designed for DifyAPI repositories that handle +service-layer operations with dependency injection patterns. +""" + +import logging + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): + """ + Factory for creating DifyAPI repository instances based on configuration. + + This factory handles the creation of repositories that are specifically designed + for service-layer operations and use dependency injection with sessionmaker + for better testability and separation of concerns. + """ + + @classmethod + def create_api_workflow_node_execution_repository( + cls, session_maker: sessionmaker + ) -> DifyAPIWorkflowNodeExecutionRepository: + """ + Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration. + + This repository is designed for service-layer operations and uses dependency injection + with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes, handling queries + that involve database-specific fields and multi-tenancy concerns. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured DifyAPIWorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e + + @classmethod + def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository: + """ + Create an APIWorkflowRunRepository instance based on configuration. + + This repository is designed for service-layer WorkflowRun operations and uses dependency + injection with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes for workflow run management, + including pagination, filtering, and bulk operations. 
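+
+        A minimal wiring sketch, mirroring the service-API controller updated in this
+        change (`app_model` and `workflow_run_id` come from that controller's scope):
+
+            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+            workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+            workflow_run = workflow_run_repo.get_workflow_run_by_id(
+                tenant_id=app_model.tenant_id, app_id=app_model.id, run_id=workflow_run_id
+            )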
+ + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured APIWorkflowRunRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY + logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create APIWorkflowRunRepository") + raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 000000000..e6a23ddf9 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,290 @@ +""" +SQLAlchemy implementation of WorkflowNodeExecutionServiceRepository. + +This module provides a concrete implementation of the service repository protocol +using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +from sqlalchemy import delete, desc, select +from sqlalchemy.orm import Session, sessionmaker + +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + + +class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + + This repository provides service-layer database operations for WorkflowNodeExecutionModel + using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository + protocol with the following features: + + - Multi-tenancy data isolation through tenant_id filtering + - Direct database model operations without domain conversion + - Batch processing for efficient large-scale operations + - Optimized query patterns for common access patterns + - Dependency injection for better testability and maintainability + - Session management and transaction handling with proper cleanup + - Maintenance operations for data lifecycle management + - Thread-safe database operations using session-per-request pattern + """ + + def __init__(self, session_maker: sessionmaker[Session]): + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for creating database sessions + """ + self._session_maker = session_maker + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method replicates the query pattern from WorkflowService.get_node_last_run() + using SQLAlchemy 2.0 style syntax. 
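+
+        For comparison, the legacy 1.x-style query this pattern replaces would look
+        roughly like (sketch, not the verbatim original):
+
+            db.session.query(WorkflowNodeExecutionModel).filter(
+                WorkflowNodeExecutionModel.tenant_id == tenant_id,
+                WorkflowNodeExecutionModel.app_id == app_id,
+                WorkflowNodeExecutionModel.workflow_id == workflow_id,
+                WorkflowNodeExecutionModel.node_id == node_id,
+            ).order_by(WorkflowNodeExecutionModel.created_at.desc()).first()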
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.created_at)) + .limit(1) + ) + + with self._session_maker() as session: + return session.scalar(stmt) + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method replicates the query pattern from WorkflowDraftVariableService + and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id) + + # Add tenant filtering if provided + if tenant_id is not None: + stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id) + + with self._session_maker() as session: + return session.scalar(stmt) + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. 
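+
+        Deletion runs as a select-then-delete loop so each transaction stays small;
+        a batch that returns fewer than `batch_size` IDs marks the final iteration.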
+ + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. 
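+
+        Tenant isolation is not enforced here; callers are expected to have resolved
+        the IDs through tenant-scoped queries first (see the protocol docstring).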
+ + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + if not execution_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(stmt) + session.commit() + return result.rowcount diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 000000000..bb66bb3a9 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,202 @@ +""" +SQLAlchemy API WorkflowRun Repository Implementation + +This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0 +style queries with proper session management and multi-tenant data isolation. + +Key Features: +- SQLAlchemy 2.0 style queries for modern database operations +- Cursor-based pagination for efficient large dataset handling +- Bulk operations with batch processing for performance +- Multi-tenant data isolation and security +- Proper session management with dependency injection + +Implementation Notes: +- Uses sessionmaker for consistent session management +- Implements cursor-based pagination using created_at timestamps +- Provides efficient bulk deletion with batch processing +- Maintains data consistency with proper transaction handling +""" + +import logging +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, cast + +from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun + +logger = logging.getLogger(__name__) + + +class DifyAPISQLAlchemyWorkflowRunRepository: + """ + SQLAlchemy implementation of APIWorkflowRunRepository. + + Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0 + style queries. Supports dependency injection through sessionmaker and + maintains proper multi-tenant data isolation. + + Args: + session_maker: SQLAlchemy sessionmaker instance for database connections + """ + + def __init__(self, session_maker: sessionmaker[Session]) -> None: + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for database connections + """ + self._session_maker = session_maker + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Implements cursor-based pagination using created_at timestamps for + efficient handling of large datasets. Filters by tenant, app, and + trigger source for proper data isolation. 
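+
+        The query fetches ``limit + 1`` rows so that ``has_more`` can be
+        determined without a separate COUNT query; the extra row is dropped
+        before the page is returned.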
+        """
+        with self._session_maker() as session:
+            # Build base query with filters
+            base_stmt = select(WorkflowRun).where(
+                WorkflowRun.tenant_id == tenant_id,
+                WorkflowRun.app_id == app_id,
+                WorkflowRun.triggered_from == triggered_from,
+            )
+
+            if last_id:
+                # Get the last workflow run for cursor-based pagination
+                last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
+                last_workflow_run = session.scalar(last_run_stmt)
+
+                if not last_workflow_run:
+                    raise ValueError("Last workflow run does not exist")
+
+                # Get records created before the last run's timestamp
+                base_stmt = base_stmt.where(
+                    WorkflowRun.created_at < last_workflow_run.created_at,
+                    WorkflowRun.id != last_workflow_run.id,
+                )
+
+            # Fetch the most recent records, plus one extra to detect whether more pages exist
+            workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
+
+            # Check if there are more records for pagination
+            has_more = len(workflow_runs) > limit
+            if has_more:
+                workflow_runs = workflow_runs[:-1]
+
+            return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
+
+    def get_workflow_run_by_id(
+        self,
+        tenant_id: str,
+        app_id: str,
+        run_id: str,
+    ) -> Optional[WorkflowRun]:
+        """
+        Get a specific workflow run by ID with tenant and app isolation.
+        """
+        with self._session_maker() as session:
+            stmt = select(WorkflowRun).where(
+                WorkflowRun.tenant_id == tenant_id,
+                WorkflowRun.app_id == app_id,
+                WorkflowRun.id == run_id,
+            )
+            return cast(Optional[WorkflowRun], session.scalar(stmt))
+
+    def get_expired_runs_batch(
+        self,
+        tenant_id: str,
+        before_date: datetime,
+        batch_size: int = 1000,
+    ) -> Sequence[WorkflowRun]:
+        """
+        Get a batch of expired workflow runs for cleanup operations.
+        """
+        with self._session_maker() as session:
+            stmt = (
+                select(WorkflowRun)
+                .where(
+                    WorkflowRun.tenant_id == tenant_id,
+                    WorkflowRun.created_at < before_date,
+                )
+                .limit(batch_size)
+            )
+            return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
+
+    def delete_runs_by_ids(
+        self,
+        run_ids: Sequence[str],
+    ) -> int:
+        """
+        Delete workflow runs by their IDs using bulk deletion.
+        """
+        if not run_ids:
+            return 0
+
+        with self._session_maker() as session:
+            stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
+            result = session.execute(stmt)
+            session.commit()
+
+            deleted_count = cast(int, result.rowcount)
+            logger.info(f"Deleted {deleted_count} workflow runs by IDs")
+            return deleted_count
+
+    def delete_runs_by_app(
+        self,
+        tenant_id: str,
+        app_id: str,
+        batch_size: int = 1000,
+    ) -> int:
+        """
+        Delete all workflow runs for a specific app in batches.
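+
+        Each iteration selects up to ``batch_size`` run IDs and removes them in
+        its own transaction, so progress is committed batch by batch rather
+        than held in one long-running transaction.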
+ """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Get a batch of run IDs to delete + stmt = ( + select(WorkflowRun.id) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + ) + .limit(batch_size) + ) + run_ids = session.scalars(stmt).all() + + if not run_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(delete_stmt) + session.commit() + + batch_deleted = result.rowcount + total_deleted += batch_deleted + + logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}") + + # If we deleted fewer records than the batch size, we're done + if batch_deleted < batch_size: + break + + logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}") + return total_deleted diff --git a/api/services/app_service.py b/api/services/app_service.py index d08462d00..db0f8cd41 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -47,8 +47,6 @@ class AppService: filters.append(App.mode == AppMode.ADVANCED_CHAT.value) elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args["mode"] == "channel": - filters.append(App.mode == AppMode.CHANNEL.value) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d58..ddd16b2e0 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder @@ -14,7 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecutionModel, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs: ) ) - while True: - with Session(db.engine).no_autoflush as session: - workflow_node_executions = ( - session.query(WorkflowNodeExecutionModel) - .filter( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.created_at - < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) - - if len(workflow_node_executions) == 0: - break - - # save workflow node executions - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder(workflow_node_executions), - ).encode("utf-8"), - ) - - workflow_node_execution_ids = [ - workflow_node_execution.id for workflow_node_execution in workflow_node_executions - ] - - # delete workflow node executions - session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), - ).delete(synchronize_session=False) - session.commit() - - click.echo( - click.style( - f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" - f" workflow node 
executions for tenant {tenant_id}" - ) - ) + # Process expired workflow node executions with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 while True: - with Session(db.engine).no_autoflush as session: - workflow_runs = ( - session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] + + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + f" workflow node executions for tenant {tenant_id}" ) + ) - if len(workflow_runs) == 0: - break + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break - # save workflow runs + # Process expired workflow runs with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder( - [workflow_run.to_dict() for workflow_run in workflow_runs], - ), - ).encode("utf-8"), + while True: + # Get a batch of expired workflow runs for backup + workflow_runs = workflow_run_repo.get_expired_runs_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_runs) == 0: + break + + # Save workflow runs to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_run.to_dict() for workflow_run in workflow_runs], + ), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] + + # Delete the backed up workflow runs + deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}" + f" workflow runs for tenant {tenant_id}" ) + ) - workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] - - # delete workflow runs - session.query(WorkflowRun).filter( - 
WorkflowRun.id.in_(workflow_run_ids),
-                    ).delete(synchronize_session=False)
-                    session.commit()
+            # If we got fewer than the batch size, we're done
+            if len(workflow_runs) < batch:
+                break
 
     @classmethod
     def process(cls, days: int, batch: int, tenant_ids: list[str]):
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 44fd72b5e..f306e1f06 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum
 from typing import Any, ClassVar
 
-from sqlalchemy import Engine, orm, select
+from sqlalchemy import Engine, orm
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy.sql.expression import and_, or_
 
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader
 from factories.variable_factory import build_segment, segment_to_variable
 from models import App, Conversation
 from models.enums import DraftVariableType
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
+from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
+from repositories.factory import DifyAPIRepositoryFactory
 
 _logger = logging.getLogger(__name__)
 
@@ -117,7 +118,24 @@ class WorkflowDraftVariableService:
     _session: Session
 
     def __init__(self, session: Session) -> None:
+        """
+        Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
+
+        Args:
+            session (Session): The SQLAlchemy session used to execute database queries.
+                The provided session must be bound to an `Engine` object, not a specific `Connection`.
+
+        Raises:
+            AssertionError: If the provided session is not bound to an `Engine` object.
+        """
         self._session = session
+        engine = session.get_bind()
+        # Ensure the session is bound to an engine.
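+        # get_bind() may return an Engine or a Connection; the repository's
+        # sessionmaker below requires an Engine.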
+ assert isinstance(engine, Engine) + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() @@ -248,8 +266,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) - node_exec = self._session.scalars(query).first() + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", @@ -298,6 +315,8 @@ class WorkflowDraftVariableService: def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name): + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") if variable_type == DraftVariableType.CONVERSATION: return self._reset_conv_var(workflow, variable) else: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d308..e43999a8c 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,9 +2,9 @@ import threading from collections.abc import Sequence from typing import Optional +from sqlalchemy.orm import sessionmaker + import contexts -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -15,10 +15,18 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) -from models.workflow import WorkflowNodeExecutionTriggeredFrom +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: + def __init__(self): + """Initialize WorkflowRunService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -62,45 +70,16 @@ class WorkflowRunService: :param args: request args """ limit = int(args.get("limit", 20)) + last_id = args.get("last_id") - base_query = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + limit=limit, + last_id=last_id, ) - if args.get("last_id"): - last_workflow_run = base_query.filter( - WorkflowRun.id == 
args.get("last_id"), - ).first() - - if not last_workflow_run: - raise ValueError("Last workflow run not exists") - - workflow_runs = ( - base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id - ) - .order_by(WorkflowRun.created_at.desc()) - .limit(limit) - .all() - ) - else: - workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() - - has_more = False - if len(workflow_runs) == limit: - current_page_first_workflow_run = workflow_runs[-1] - rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id, - ).count() - - if rest_count > 0: - has_more = True - - return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail @@ -108,18 +87,12 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = ( - db.session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ) - .first() + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=run_id, ) - return workflow_run - def get_workflow_run_node_executions( self, app_model: App, @@ -137,17 +110,13 @@ class WorkflowRunService: if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - user=user, + # Get tenant_id from user + tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") + + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=run_id, ) - - # Use the repository to get the database models directly - order_config = OrderConfig(order_by=["index"], order_direction="desc") - workflow_node_executions = repository.get_db_models_by_workflow_run( - workflow_run_id=run_id, order_config=order_config - ) - - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2be57fd51..0149d5034 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -7,13 +7,13 @@ from typing import Any, Optional from uuid import uuid4 from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -41,6 +41,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter 
import WorkflowConverter @@ -57,21 +58,32 @@ class WorkflowService: Workflow Service """ + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize WorkflowService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: - # TODO(QuantumGhost): This query is not fully covered by index. - criteria = ( - WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_id == workflow.id, - WorkflowNodeExecutionModel.node_id == node_id, + """ + Get the most recent execution for a specific node. + + Args: + app_model: The application model + workflow: The workflow model + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=workflow.id, + node_id=node_id, ) - node_exec = ( - db.session.query(WorkflowNodeExecutionModel) - .filter(*criteria) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .first() - ) - return node_exec def is_workflow_exist(self, app_model: App) -> bool: return ( @@ -396,7 +408,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -404,8 +416,9 @@ class WorkflowService: ) repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution = repository.to_db_model(node_execution) + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -418,6 +431,7 @@ class WorkflowService: ) draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) session.commit() + return workflow_node_execution def run_free_workflow_node( @@ -429,7 +443,7 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( + node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, @@ -441,7 +455,7 @@ class WorkflowService: node_id=node_id, ) - return workflow_node_execution + return node_execution def _handle_node_run_result( self, diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4a62cb74b..179adcbd6 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ import click from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker from 
extensions.ext_database import db from models import ( @@ -31,7 +32,8 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from repositories.factory import DifyAPIRepositoryFactory @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -189,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: str): def _delete_app_workflow_runs(tenant_id: str, app_id: str): - def del_workflow_run(workflow_run_id: str): - db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False) + """Delete all workflow runs for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - _delete_records( - """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_run, - "workflow run", + deleted_count = workflow_run_repo.delete_runs_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}") + def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id == workflow_node_execution_id - ).delete(synchronize_session=False) + """Delete all workflow node executions for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + deleted_count = node_execution_repo.delete_executions_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}") + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/tests/unit_tests/core/repositories/__init__.py b/api/tests/unit_tests/core/repositories/__init__.py new file mode 100644 index 000000000..c65d7da61 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/__init__.py @@ -0,0 +1 @@ +# Unit tests for core repositories module diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py new file mode 100644 index 000000000..fce4a6fb6 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -0,0 +1,455 @@ +""" +Unit tests for the RepositoryFactory. + +This module tests the factory pattern implementation for creating repository instances +based on configuration, including error handling and validation. 
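+
+Most tests patch ``dify_config`` and the factory's validation helpers, so they
+exercise the factory's control flow without importing real repository classes
+or touching a database.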
+""" + +from unittest.mock import MagicMock, patch + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +class TestRepositoryFactory: + """Test cases for RepositoryFactory.""" + + def test_import_class_success(self): + """Test successful class import.""" + # Test importing a real class + class_path = "unittest.mock.MagicMock" + result = DifyCoreRepositoryFactory._import_class(class_path) + assert result is MagicMock + + def test_import_class_invalid_path(self): + """Test import with invalid module path.""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("invalid.module.path") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_invalid_class_name(self): + """Test import with invalid class name.""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_malformed_path(self): + """Test import with malformed path (no dots).""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("invalidpath") + assert "Cannot import repository class" in str(exc_info.value) + + def test_validate_repository_interface_success(self): + """Test successful interface validation.""" + + # Create a mock class that implements the required methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + # Create a mock interface with the same methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + # Should not raise an exception + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_repository_interface_missing_methods(self): + """Test interface validation with missing methods.""" + + # Create a mock class that doesn't implement all required methods + class IncompleteRepository: + def save(self): + pass + + # Missing get_by_id method + + # Create a mock interface with required methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) + assert "does not implement required methods" in str(exc_info.value) + assert "get_by_id" in str(exc_info.value) + + def test_validate_constructor_signature_success(self): + """Test successful constructor signature validation.""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from): + pass + + # Should not raise an exception + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_missing_params(self): + """Test constructor validation with missing parameters.""" + + class IncompleteRepository: + def 
__init__(self, session_factory, user): + # Missing app_id and triggered_from parameters + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature( + IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value) + + def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture): + """Test constructor validation when inspection fails.""" + # Mock inspect.signature to raise an exception + mocker.patch("inspect.signature", side_effect=Exception("Inspection failed")) + + class MockRepository: + def __init__(self, session_factory): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) + assert "Failed to validate constructor signature" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_import_error(self, mock_config): + """Test WorkflowExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with validation error.""" + # Setup mock 
configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import to succeed but validation to fail + mock_repository_class = MagicMock() + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + DifyCoreRepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowNodeExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + app_id = "test-app-id" + triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert 
result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_import_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + def test_repository_import_error_exception(self): + """Test RepositoryImportError exception.""" + error_message = "Test error message" + exception = RepositoryImportError(error_message) + assert str(exception) == error_message + assert isinstance(exception, Exception) + + @patch("core.repositories.factory.dify_config") + def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture): + """Test repository creation with Engine instead of sessionmaker.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies with Engine instead of sessionmaker + mock_engine = MagicMock(spec=Engine) + mock_user = MagicMock(spec=Account) + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_engine, # Using Engine instead of sessionmaker + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Verify the repository was created with the Engine + mock_repository_class.assert_called_once_with( + session_factory=mock_engine, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_validation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with validation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import to succeed but validation to fail + mock_repository_class = MagicMock() + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + DifyCoreRepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + 
user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) + + def test_validate_repository_interface_with_private_methods(self): + """Test interface validation ignores private methods.""" + + # Create a mock class with private methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Create a mock interface with private methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Should not raise an exception (private methods are ignored) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_constructor_signature_with_extra_params(self): + """Test constructor validation with extra parameters (should pass).""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None): + pass + + # Should not raise an exception (extra parameters are allowed) + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_with_kwargs(self): + """Test constructor validation with **kwargs (current implementation doesn't support this).""" + + class MockRepository: + def __init__(self, session_factory, user, **kwargs): + pass + + # Current implementation doesn't handle **kwargs, so this should raise an exception + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py index 223020c2c..2c87eaf80 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py @@ -10,7 +10,8 @@ from services.workflow_service import 
DraftWorkflowDeletionError, WorkflowInUseE @pytest.fixture def workflow_setup(): - workflow_service = WorkflowService() + mock_session_maker = MagicMock() + workflow_service = WorkflowService(mock_session_maker) session = MagicMock(spec=Session) tenant_id = "test-tenant-id" workflow_id = "test-workflow-id" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index c5c9cf105..8b1348b75 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -1,14 +1,14 @@ import dataclasses import secrets -from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy import Engine from sqlalchemy.orm import Session from core.variables import StringSegment from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes import NodeType +from core.workflow.nodes.enums import NodeType from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable from services.workflow_draft_variable_service import ( @@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import ( ) +@pytest.fixture +def mock_engine() -> Engine: + return Mock(spec=Engine) + + +@pytest.fixture +def mock_session(mock_engine) -> Session: + mock_session = Mock(spec=Session) + mock_session.get_bind.return_value = mock_engine + return mock_session + + class TestDraftVariableSaver: def _get_test_app_id(self): suffix = secrets.token_hex(6) return f"test_app_id_{suffix}" def test__should_variable_be_visible(self): - mock_session = mock.MagicMock(spec=Session) + mock_session = MagicMock(spec=Session) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -70,7 +82,7 @@ class TestDraftVariableSaver: ), ] - mock_session = mock.MagicMock(spec=Session) + mock_session = MagicMock(spec=Session) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService: conversation_variables=[], ) - def test_reset_conversation_variable(self): + def test_reset_conversation_variable(self, mock_session): """Test resetting a conversation variable""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService: mock_reset_conv.assert_called_once_with(workflow, variable) assert result == expected_result - def test_reset_node_variable_with_no_execution_id(self): + def test_reset_node_variable_with_no_execution_id(self, mock_session): """Test resetting a node variable with no execution ID - should delete variable""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService: mock_session.flush.assert_called_once() assert result is None - def test_reset_node_variable_with_missing_execution_record(self): + def test_reset_node_variable_with_missing_execution_record( + self, + mock_engine, + mock_session, + monkeypatch, + ): """Test resetting a node variable when execution record doesn't exist""" - mock_session = Mock(spec=Session) + mock_repo_session = 
Mock(spec=Session) + + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__exit__.return_value = None + monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker) service = WorkflowDraftVariableService(mock_session) + # Mock the repository to return None (no execution record found) + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = None + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Mock session.scalars to return None (no execution record found) - mock_scalars = Mock() - mock_scalars.first.return_value = None - mock_session.scalars.return_value = mock_scalars + # Variable is editable by default from factory method result = service._reset_node_var_or_sys_var(workflow, variable) + mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False) # Should delete the variable and return None mock_session.delete.assert_called_once_with(instance=variable) mock_session.flush.assert_called_once() assert result is None - def test_reset_node_variable_with_valid_execution_record(self): + def test_reset_node_variable_with_valid_execution_record( + self, + mock_session, + monkeypatch, + ): """Test resetting a node variable with valid execution record - should restore from execution""" - mock_session = Mock(spec=Session) + mock_repo_session = Mock(spec=Session) + + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__exit__.return_value = None + mock_session_maker = monkeypatch.setattr( + "services.workflow_draft_variable_service.sessionmaker", mock_session_maker + ) service = WorkflowDraftVariableService(mock_session) + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Create mock execution record - mock_execution = Mock(spec=WorkflowNodeExecutionModel) - mock_execution.process_data_dict = {"test_var": "process_value"} - mock_execution.outputs_dict = {"test_var": "output_value"} - - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Variable is editable by default from factory method # Mock workflow methods mock_node_config = {"type": "test_node"} @@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService: # Should return the updated variable assert result == variable - def 
test_reset_non_editable_system_variable_raises_error(self): + def test_reset_non_editable_system_variable_raises_error(self, mock_session): """Test that resetting a non-editable system variable raises an error""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService: editable=False, # Non-editable system variable ) - # Mock the service to properly check system variable editability - with patch.object(service, "reset_variable") as mock_reset: + with pytest.raises(VariableResetError) as exc_info: + service.reset_variable(workflow, variable) + assert "cannot reset system variable" in str(exc_info.value) + assert f"variable_id={variable.id}" in str(exc_info.value) - def side_effect(wf, var): - if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name): - raise VariableResetError(f"cannot reset system variable, variable_id={var.id}") - return var - - mock_reset.side_effect = side_effect - - with pytest.raises(VariableResetError) as exc_info: - service.reset_variable(workflow, variable) - assert "cannot reset system variable" in str(exc_info.value) - assert f"variable_id={variable.id}" in str(exc_info.value) - - def test_reset_editable_system_variable_succeeds(self): + def test_reset_editable_system_variable_succeeds(self, mock_session): """Test that resetting an editable system variable succeeds""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService: mock_execution = Mock(spec=WorkflowNodeExecutionModel) mock_execution.outputs_dict = {"sys.files": "[]"} - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution result = service._reset_node_var_or_sys_var(workflow, variable) @@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService: assert variable.last_edited_at is None mock_session.flush.assert_called() - def test_reset_query_system_variable_succeeds(self): + def test_reset_query_system_variable_succeeds(self, mock_session): """Test that resetting query system variable (another editable one) succeeds""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService: mock_execution = Mock(spec=WorkflowNodeExecutionModel) mock_execution.outputs_dict = {"sys.query": "reset query"} - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution result = service._reset_node_var_or_sys_var(workflow, variable) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py new file mode 100644 index 000000000..32d2f8b7e --- /dev/null +++ 
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
new file mode 100644
index 000000000..32d2f8b7e
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
@@ -0,0 +1,288 @@
+from datetime import datetime
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+from sqlalchemy.orm import Session
+
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.sqlalchemy_api_workflow_node_execution_repository import (
+    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
+)
+
+
+class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
+    @pytest.fixture
+    def repository(self):
+        mock_session_maker = MagicMock()
+        return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
+
+    @pytest.fixture
+    def mock_execution(self):
+        execution = MagicMock(spec=WorkflowNodeExecutionModel)
+        execution.id = str(uuid4())
+        execution.tenant_id = "tenant-123"
+        execution.app_id = "app-456"
+        execution.workflow_id = "workflow-789"
+        execution.workflow_run_id = "run-101"
+        execution.node_id = "node-202"
+        execution.index = 1
+        execution.created_at = "2023-01-01T00:00:00Z"
+        return execution
+
+    def test_get_node_last_execution_found(self, repository, mock_execution):
+        """Test getting the last execution for a node when it exists."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        mock_session.scalar.return_value = mock_execution
+
+        # Act
+        result = repository.get_node_last_execution(
+            tenant_id="tenant-123",
+            app_id="app-456",
+            workflow_id="workflow-789",
+            node_id="node-202",
+        )
+
+        # Assert
+        assert result == mock_execution
+        mock_session.scalar.assert_called_once()
+        # Verify the query was constructed correctly
+        call_args = mock_session.scalar.call_args[0][0]
+        assert hasattr(call_args, "compile")  # It's a SQLAlchemy statement
+
+    def test_get_node_last_execution_not_found(self, repository):
+        """Test getting the last execution for a node when it doesn't exist."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        mock_session.scalar.return_value = None
+
+        # Act
+        result = repository.get_node_last_execution(
+            tenant_id="tenant-123",
+            app_id="app-456",
+            workflow_id="workflow-789",
+            node_id="node-202",
+        )
+
+        # Assert
+        assert result is None
+        mock_session.scalar.assert_called_once()
+
+    def test_get_executions_by_workflow_run(self, repository, mock_execution):
+        """Test getting all executions for a workflow run."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        executions = [mock_execution]
+        mock_session.execute.return_value.scalars.return_value.all.return_value = executions
+
+        # Act
+        result = repository.get_executions_by_workflow_run(
+            tenant_id="tenant-123",
+            app_id="app-456",
+            workflow_run_id="run-101",
+        )
+
+        # Assert
+        assert result == executions
+        mock_session.execute.assert_called_once()
+        # Verify the query was constructed correctly
+        call_args = mock_session.execute.call_args[0][0]
+        assert hasattr(call_args, "compile")  # It's a SQLAlchemy statement
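The lookup tests above only assert that a compilable SQLAlchemy statement reaches `Session.scalar()` or `Session.execute()`; they deliberately avoid coupling to the exact SQL. For orientation, a statement like the following would satisfy `get_node_last_execution` (the column names and the `created_at` ordering are assumptions inferred from the fixture, not taken from the repository's source):

```python
from sqlalchemy import select

from models.workflow import WorkflowNodeExecutionModel


def build_node_last_execution_stmt(tenant_id: str, app_id: str, workflow_id: str, node_id: str):
    # Hypothetical reconstruction: scope by tenant/app/workflow/node and
    # take the most recent record; session.scalar(stmt) then returns one
    # row or None, matching the found/not-found tests above.
    return (
        select(WorkflowNodeExecutionModel)
        .where(
            WorkflowNodeExecutionModel.tenant_id == tenant_id,
            WorkflowNodeExecutionModel.app_id == app_id,
            WorkflowNodeExecutionModel.workflow_id == workflow_id,
            WorkflowNodeExecutionModel.node_id == node_id,
        )
        .order_by(WorkflowNodeExecutionModel.created_at.desc())
        .limit(1)
    )
```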
+    def test_get_executions_by_workflow_run_empty(self, repository):
+        """Test getting executions for a workflow run when none exist."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        mock_session.execute.return_value.scalars.return_value.all.return_value = []
+
+        # Act
+        result = repository.get_executions_by_workflow_run(
+            tenant_id="tenant-123",
+            app_id="app-456",
+            workflow_run_id="run-101",
+        )
+
+        # Assert
+        assert result == []
+        mock_session.execute.assert_called_once()
+
+    def test_get_execution_by_id_found(self, repository, mock_execution):
+        """Test getting execution by ID when it exists."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        mock_session.scalar.return_value = mock_execution
+
+        # Act
+        result = repository.get_execution_by_id(mock_execution.id)
+
+        # Assert
+        assert result == mock_execution
+        mock_session.scalar.assert_called_once()
+
+    def test_get_execution_by_id_not_found(self, repository):
+        """Test getting execution by ID when it doesn't exist."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+        mock_session.scalar.return_value = None
+
+        # Act
+        result = repository.get_execution_by_id("non-existent-id")
+
+        # Assert
+        assert result is None
+        mock_session.scalar.assert_called_once()
+
+    def test_repository_implements_protocol(self, repository):
+        """Test that the repository implements the required protocol methods."""
+        # Verify all protocol methods are implemented
+        assert hasattr(repository, "get_node_last_execution")
+        assert hasattr(repository, "get_executions_by_workflow_run")
+        assert hasattr(repository, "get_execution_by_id")
+
+        # Verify methods are callable
+        assert callable(repository.get_node_last_execution)
+        assert callable(repository.get_executions_by_workflow_run)
+        assert callable(repository.get_execution_by_id)
+        assert callable(repository.delete_expired_executions)
+        assert callable(repository.delete_executions_by_app)
+        assert callable(repository.get_expired_executions_batch)
+        assert callable(repository.delete_executions_by_ids)
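The deletion tests that follow drive a select-then-delete loop and tell the two statement types apart with `hasattr(stmt, "limit")`: a `Select` exposes a `.limit()` method, a `Delete` does not. The behavior they encode, in sketch form (an assumed reconstruction; the real repository may structure the loop differently):

```python
from datetime import datetime

from sqlalchemy import delete, select
from sqlalchemy.orm import sessionmaker

from models.workflow import WorkflowNodeExecutionModel


def delete_expired_executions_sketch(
    session_maker: sessionmaker, tenant_id: str, before_date: datetime, batch_size: int = 1000
) -> int:
    # Batched delete: fetch a page of IDs, delete them, and stop once a
    # short page shows nothing is left. The tests expect exactly one
    # select and one delete when fewer than batch_size IDs remain.
    total_deleted = 0
    with session_maker() as session:
        while True:
            ids = session.execute(
                select(WorkflowNodeExecutionModel.id)
                .where(
                    WorkflowNodeExecutionModel.tenant_id == tenant_id,
                    WorkflowNodeExecutionModel.created_at < before_date,
                )
                .limit(batch_size)
            ).scalars().all()
            if not ids:
                break
            result = session.execute(
                delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(ids))
            )
            total_deleted += result.rowcount
            if len(ids) < batch_size:
                break
        session.commit()
    return total_deleted
```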
+    def test_delete_expired_executions(self, repository):
+        """Test deleting expired executions."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Return fewer IDs than batch_size so the deletion loop stops after one pass
+        execution_ids = ["id1", "id2"]
+
+        # Mock execute method to handle both select and delete statements
+        def mock_execute(stmt):
+            mock_result = MagicMock()
+            # For select statements, return execution IDs
+            if hasattr(stmt, "limit"):  # This is our select statement
+                mock_result.scalars.return_value.all.return_value = execution_ids
+            else:  # This is our delete statement
+                mock_result.rowcount = 2
+            return mock_result
+
+        mock_session.execute.side_effect = mock_execute
+
+        before_date = datetime(2023, 1, 1)
+
+        # Act
+        result = repository.delete_expired_executions(
+            tenant_id="tenant-123",
+            before_date=before_date,
+            batch_size=1000,
+        )
+
+        # Assert
+        assert result == 2
+        assert mock_session.execute.call_count == 2  # One select call, one delete call
+        mock_session.commit.assert_called_once()
+
+    def test_delete_executions_by_app(self, repository):
+        """Test deleting executions by app."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Return fewer IDs than batch_size so the deletion loop stops after one pass
+        execution_ids = ["id1", "id2"]
+
+        # Mock execute method to handle both select and delete statements
+        def mock_execute(stmt):
+            mock_result = MagicMock()
+            # For select statements, return execution IDs
+            if hasattr(stmt, "limit"):  # This is our select statement
+                mock_result.scalars.return_value.all.return_value = execution_ids
+            else:  # This is our delete statement
+                mock_result.rowcount = 2
+            return mock_result
+
+        mock_session.execute.side_effect = mock_execute
+
+        # Act
+        result = repository.delete_executions_by_app(
+            tenant_id="tenant-123",
+            app_id="app-456",
+            batch_size=1000,
+        )
+
+        # Assert
+        assert result == 2
+        assert mock_session.execute.call_count == 2  # One select call, one delete call
+        mock_session.commit.assert_called_once()
+
+    def test_get_expired_executions_batch(self, repository):
+        """Test getting expired executions batch for backup."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Create mock execution objects
+        mock_execution1 = MagicMock()
+        mock_execution1.id = "exec-1"
+        mock_execution2 = MagicMock()
+        mock_execution2.id = "exec-2"
+
+        mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
+
+        before_date = datetime(2023, 1, 1)
+
+        # Act
+        result = repository.get_expired_executions_batch(
+            tenant_id="tenant-123",
+            before_date=before_date,
+            batch_size=1000,
+        )
+
+        # Assert
+        assert len(result) == 2
+        assert result[0].id == "exec-1"
+        assert result[1].id == "exec-2"
+        mock_session.execute.assert_called_once()
+
+    def test_delete_executions_by_ids(self, repository):
+        """Test deleting executions by IDs."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Mock the delete query result
+        mock_result = MagicMock()
+        mock_result.rowcount = 3
+        mock_session.execute.return_value = mock_result
+
+        execution_ids = ["id1", "id2", "id3"]
+
+        # Act
+        result = repository.delete_executions_by_ids(execution_ids)
+
+        # Assert
+        assert result == 3
+        mock_session.execute.assert_called_once()
+        mock_session.commit.assert_called_once()
+
+    def test_delete_executions_by_ids_empty_list(self, repository):
+        """Test deleting executions with empty ID list."""
+        # Arrange
+        mock_session = MagicMock(spec=Session)
+        repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Act
+        result = repository.delete_executions_by_ids([])
+
+        # Assert
+        assert result == 0
+        mock_session.execute.assert_not_called()
+        mock_session.commit.assert_not_called()
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py
index 13393668e..9700cbaf0 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py
@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
 class TestWorkflowService:
     @pytest.fixture
     def workflow_service(self):
-        return WorkflowService()
+        mock_session_maker = MagicMock()
+        return WorkflowService(mock_session_maker)
 
     @pytest.fixture
     def mock_app(self):
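The docker configuration below mirrors the repository settings: each variable holds a dotted `module.ClassName` path, so a deployment can swap in another implementation without code changes. Resolving such a path is a short `importlib` exercise; a sketch under that assumption (the project's actual factory wiring is not shown in this diff):

```python
import importlib


def load_class_from_path(dotted_path: str) -> type:
    """Resolve a 'package.module.ClassName' setting into a class object."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Usage with the shipped default (illustrative):
# repo_cls = load_class_from_path(
#     "core.repositories.sqlalchemy_workflow_execution_repository"
#     ".SQLAlchemyWorkflowExecutionRepository"
# )
```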
diff --git a/docker/.env.example b/docker/.env.example
index 84b6152f0..dabd66f28 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -799,6 +799,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
 # hybrid: Save new data to object storage, read from both object storage and RDBMS
 WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
 
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+
 # HTTP request node in workflow configuration
 HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index ac9953aa3..61362ed9f 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -354,6 +354,10 @@ x-shared-env: &shared-api-worker-env
   WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
   WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
   WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
+  CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
+  CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
+  API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
+  API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
   HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
   HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
   HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
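In docker-compose, `${VAR:-default}` falls back to the default when the variable is unset or empty, so deployments that never set these keys keep the stock SQLAlchemy-backed repositories. The equivalent lookup in Python, for reference (a sketch, not project code):

```python
import os

# ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-<default>} behaves like an
# environment lookup with a fallback; `or` also covers the empty-string
# case, matching the :- operator's semantics.
core_workflow_execution_repository = os.environ.get("CORE_WORKFLOW_EXECUTION_REPOSITORY") or (
    "core.repositories.sqlalchemy_workflow_execution_repository"
    ".SQLAlchemyWorkflowExecutionRepository"
)
```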