feat(api/repo): Allow configuring repository implementations (#21458)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
-LAN-
2025-07-14 14:54:38 +08:00
committed by GitHub
parent b27c540379
commit 6eb155ae69
38 changed files with 2361 additions and 329 deletions

View File

@@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
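
These values are dotted import paths resolved at runtime, so a deployment can swap in its own backend without touching core code. A minimal sketch of a custom override, assuming a hypothetical `custom/repositories.py` module on the API's import path:

```python
# custom/repositories.py (hypothetical, not part of this commit)
from core.repositories.sqlalchemy_workflow_execution_repository import (
    SQLAlchemyWorkflowExecutionRepository,
)


class AuditedWorkflowExecutionRepository(SQLAlchemyWorkflowExecutionRepository):
    """Example subclass that could add logging, caching, or metrics.

    The constructor signature is inherited, so the factory's
    signature validation passes unchanged.
    """


# .env
# CORE_WORKFLOW_EXECUTION_REPOSITORY=custom.repositories.AuditedWorkflowExecutionRepository
```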

View File

@@ -537,6 +537,33 @@ class WorkflowNodeExecutionConfig(BaseSettings):
)
class RepositoryConfig(BaseSettings):
"""
Configuration for repository implementations
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
"Specify as a module path",
default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
)
API_WORKFLOW_RUN_REPOSITORY: str = Field(
description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
)
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
@@ -903,6 +930,7 @@ class FeatureConfig(
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
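
Since `RepositoryConfig` is mixed into `FeatureConfig`, pydantic-settings populates these fields from the environment automatically and exposes them on `dify_config`. A rough standalone illustration of that loading behavior:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class RepositoryConfig(BaseSettings):
    CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
        default=(
            "core.repositories.sqlalchemy_workflow_execution_repository"
            ".SQLAlchemyWorkflowExecutionRepository"
        ),
    )


# With no environment variable set, the default applies.
print(RepositoryConfig().CORE_WORKFLOW_EXECUTION_REPOSITORY)

# An environment variable overrides the default at instantiation time.
os.environ["CORE_WORKFLOW_EXECUTION_REPOSITORY"] = "custom.repositories.MyRepository"
print(RepositoryConfig().CORE_WORKFLOW_EXECUTION_REPOSITORY)
```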

View File

@@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode == AppMode.CHANNEL:
raise AppNotFoundError()
if mode is not None:
if isinstance(mode, list):

View File

@@ -3,7 +3,7 @@ import logging
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource):
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
# Use repository to get workflow run
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = workflow_run_repo.get_workflow_run_by_id(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
run_id=workflow_run_id,
)
return workflow_run

View File

@@ -3,6 +3,8 @@ import logging
import uuid
from typing import Optional, Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
messages: list[Message] = (
db.session.query(Message)
.filter(
Message.conversation_id == self.message.conversation_id,
)
messages = (
(
db.session.execute(
select(Message)
.where(Message.conversation_id == self.message.conversation_id)
.order_by(Message.created_at.desc())
)
)
.scalars()
.all()
)
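
This is the SQLAlchemy 1.x-to-2.0 migration pattern that recurs throughout the commit: `db.session.query(...).filter(...)` becomes a `select()` statement executed through the session. A condensed before/after sketch, reusing the `db` and `Message` names from this diff:

```python
from sqlalchemy import select

# Legacy 1.x query style (removed)
messages = (
    db.session.query(Message)
    .filter(Message.conversation_id == conversation_id)
    .order_by(Message.created_at.desc())
    .all()
)

# 2.0 select style (added): execute, then unwrap ORM objects via scalars()
messages = (
    db.session.execute(
        select(Message)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
    )
    .scalars()
    .all()
)
```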

View File

@@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
@@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,

View File

@@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,

View File

@@ -3,7 +3,6 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -68,7 +67,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
)
logger = logging.getLogger(__name__)
@@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline:
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = workflow_execution.workflow_id
workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

View File

@@ -1,6 +1,8 @@
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import select
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.model_manager import ModelInstance
@@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
from models.workflow import Workflow, WorkflowRun
class TokenBufferMemory:
def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
def __init__(
self,
conversation: Conversation,
model_instance: ModelInstance,
) -> None:
self.conversation = conversation
self.model_instance = model_instance
@@ -36,20 +42,8 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
query = (
db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id,
Message.parent_message_id,
Message.answer_tokens,
)
.filter(
Message.conversation_id == self.conversation.id,
)
.order_by(Message.created_at.desc())
stmt = (
select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
@@ -57,7 +51,9 @@ class TokenBufferMemory:
else:
message_limit = 500
messages = query.limit(message_limit).all()
stmt = stmt.limit(message_limit)
messages = db.session.scalars(stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of the last message
@@ -74,18 +70,20 @@ class TokenBufferMemory:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_run = db.session.scalar(
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
if message.workflow_run_id:
workflow_run = (
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
)
if workflow_run and workflow_run.workflow:
file_extra_config = FileUploadConfigManager.convert(
workflow_run.workflow.features_dict, is_vision=False
)
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:

View File

@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

View File

@@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

View File

@@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

View File

@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)

View File

@@ -1,10 +1,11 @@
from typing import Any
from collections.abc import Sequence
from constants import UUID_NIL
from models import Message
def extract_thread_messages(messages: list[Any]):
thread_messages = []
def extract_thread_messages(messages: Sequence[Message]):
thread_messages: list[Message] = []
next_message = None
for message in messages:

View File

@@ -1,3 +1,5 @@
from sqlalchemy import select
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
@@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int:
Get the number of thread messages based on the parent message id.
"""
# Fetch all messages related to the conversation
query = (
db.session.query(
Message.id,
Message.parent_message_id,
Message.answer,
)
.filter(
Message.conversation_id == conversation_id,
)
.order_by(Message.created_at.desc())
)
stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
messages = query.all()
messages = db.session.scalars(stmt).all()
# Extract thread messages
thread_messages = extract_thread_messages(messages)

View File

@@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repositories package.
"""
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@@ -0,0 +1,224 @@
"""
Repository factory for dynamically creating repository instances based on configuration.
This module provides a Django-like settings system for repository implementations,
allowing users to configure different repository backends through string paths.
"""
import importlib
import inspect
import logging
from typing import Protocol, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class RepositoryImportError(Exception):
"""Raised when a repository implementation cannot be imported or instantiated."""
pass
class DifyCoreRepositoryFactory:
"""
Factory for creating repository instances based on configuration.
This factory supports Django-like settings where repository implementations
are specified as module paths (e.g., 'module.submodule.ClassName').
"""
@staticmethod
def _import_class(class_path: str) -> type:
"""
Import a class from a module path string.
Args:
class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
Returns:
The imported class
Raises:
RepositoryImportError: If the class cannot be imported
"""
try:
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
repo_class = getattr(module, class_name)
assert isinstance(repo_class, type)
return repo_class
except (ValueError, ImportError, AttributeError) as e:
raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
@staticmethod
def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore
"""
Validate that a class implements the expected repository interface.
Args:
repository_class: The class to validate
expected_interface: The expected interface/protocol
Raises:
RepositoryImportError: If the class doesn't implement the interface
"""
# Check if the class has all required methods from the protocol
required_methods = [
method
for method in dir(expected_interface)
if not method.startswith("_") and callable(getattr(expected_interface, method, None))
]
missing_methods = []
for method_name in required_methods:
if not hasattr(repository_class, method_name):
missing_methods.append(method_name)
if missing_methods:
raise RepositoryImportError(
f"Repository class '{repository_class.__name__}' does not implement required methods "
f"{missing_methods} from interface '{expected_interface.__name__}'"
)
@staticmethod
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
"""
Validate that a repository class constructor accepts required parameters.
Args:
repository_class: The class to validate
required_params: List of required parameter names
Raises:
RepositoryImportError: If the constructor doesn't accept required parameters
"""
try:
# MyPy may flag the line below with the following error:
#
# > Accessing "__init__" on an instance is unsound, since
# > instance.__init__ could be from an incompatible subclass.
#
# Despite this, we need to ensure that the constructor of `repository_class`
# has a compatible signature.
signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
param_names = list(signature.parameters.keys())
# Remove 'self' parameter
if "self" in param_names:
param_names.remove("self")
missing_params = [param for param in required_params if param not in param_names]
if missing_params:
raise RepositoryImportError(
f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
f"{missing_params}. Expected parameters: {required_params}"
)
except Exception as e:
raise RepositoryImportError(
f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
) from e
@classmethod
def create_workflow_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
) -> WorkflowExecutionRepository:
"""
Create a WorkflowExecutionRepository instance based on configuration.
Args:
session_factory: SQLAlchemy sessionmaker or engine
user: Account or EndUser object
app_id: Application ID
triggered_from: Source of the execution trigger
Returns:
Configured WorkflowExecutionRepository instance
Raises:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
app_id=app_id,
triggered_from=triggered_from,
)
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
except Exception as e:
logger.exception("Failed to create WorkflowExecutionRepository")
raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
@classmethod
def create_workflow_node_execution_repository(
cls,
session_factory: Union[sessionmaker, Engine],
user: Union[Account, EndUser],
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
) -> WorkflowNodeExecutionRepository:
"""
Create a WorkflowNodeExecutionRepository instance based on configuration.
Args:
session_factory: SQLAlchemy sessionmaker or engine
user: Account or EndUser object
app_id: Application ID
triggered_from: Source of the execution trigger
Returns:
Configured WorkflowNodeExecutionRepository instance
Raises:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
app_id=app_id,
triggered_from=triggered_from,
)
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
except Exception as e:
logger.exception("Failed to create WorkflowNodeExecutionRepository")
raise RepositoryImportError(
f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
) from e
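
End to end, a caller resolves whatever class the config names and receives a ready repository. A usage sketch, assuming a Flask app context where `current_user` is an Account or EndUser and `app_model` is the loaded App:

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom

session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
    session_factory=session_factory,
    user=current_user,  # assumed: an Account or EndUser from the request context
    app_id=app_model.id,  # assumed: the App loaded for this request
    triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
```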

View File

@@ -50,7 +50,6 @@ class AppMode(StrEnum):
CHAT = "chat"
ADVANCED_CHAT = "advanced-chat"
AGENT_CHAT = "agent-chat"
CHANNEL = "channel"
@classmethod
def value_of(cls, value: str) -> "AppMode":
@@ -934,7 +933,7 @@ class Message(Base):
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID)
workflow_run_id: Mapped[str] = db.Column(StringUUID)
@property
def inputs(self):

View File


@@ -0,0 +1,197 @@
"""
Service-layer repository protocol for WorkflowNodeExecutionModel operations.
This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
that abstracts database queries currently done directly in service classes. This repository is
specifically designed for service-layer needs and is separate from the core domain repository.
The service repository handles operations that require access to database-specific fields like
tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
"""
from collections.abc import Sequence
from datetime import datetime
from typing import Optional, Protocol
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
"""
Protocol for service-layer operations on WorkflowNodeExecutionModel.
This repository provides database access patterns specifically needed by service classes,
handling queries that involve database-specific fields and multi-tenancy concerns.
Key responsibilities:
- Manages database operations for workflow node executions
- Handles multi-tenant data isolation
- Provides batch processing capabilities
- Supports execution lifecycle management
Implementation notes:
- Returns database models directly (WorkflowNodeExecutionModel)
- Handles tenant/app filtering automatically
- Provides service-specific query patterns
- Focuses on database operations without domain logic
- Supports cleanup and maintenance operations
"""
def get_node_last_execution(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> Optional[WorkflowNodeExecutionModel]:
"""
Get the most recent execution for a specific node.
This method finds the latest execution of a specific node within a workflow,
ordered by creation time. Used primarily for debugging and inspection purposes.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_id: The workflow identifier
node_id: The node identifier
Returns:
The most recent WorkflowNodeExecutionModel for the node, or None if not found
"""
...
def get_executions_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get all node executions for a specific workflow run.
This method retrieves all node executions that belong to a specific workflow run,
ordered by index in descending order for proper trace visualization.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_run_id: The workflow run identifier
Returns:
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
"""
...
def get_execution_by_id(
self,
execution_id: str,
tenant_id: Optional[str] = None,
) -> Optional[WorkflowNodeExecutionModel]:
"""
Get a workflow node execution by its ID.
This method retrieves a specific execution by its unique identifier.
Tenant filtering is optional for cases where the execution ID is globally unique.
When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
set `tenant_id` to prevent horizontal privilege escalation.
Args:
execution_id: The execution identifier
tenant_id: Optional tenant identifier for additional filtering
Returns:
The WorkflowNodeExecutionModel if found, or None if not found
"""
...
def delete_expired_executions(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> int:
"""
Delete workflow node executions that are older than the specified date.
This method is used for cleanup operations to remove expired executions
in batches to avoid overwhelming the database.
Args:
tenant_id: The tenant identifier
before_date: Delete executions created before this date
batch_size: Maximum number of executions to delete in one batch
Returns:
The number of executions deleted
"""
...
def delete_executions_by_app(
self,
tenant_id: str,
app_id: str,
batch_size: int = 1000,
) -> int:
"""
Delete all workflow node executions for a specific app.
This method is used when removing an app and all its related data.
Executions are deleted in batches to avoid overwhelming the database.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
batch_size: Maximum number of executions to delete in one batch
Returns:
The total number of executions deleted
"""
...
def get_expired_executions_batch(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get a batch of expired workflow node executions for backup purposes.
This method retrieves expired executions without deleting them,
allowing the caller to backup the data before deletion.
Args:
tenant_id: The tenant identifier
before_date: Get executions created before this date
batch_size: Maximum number of executions to retrieve
Returns:
A sequence of WorkflowNodeExecutionModel instances
"""
...
def delete_executions_by_ids(
self,
execution_ids: Sequence[str],
) -> int:
"""
Delete workflow node executions by their IDs.
This method deletes specific executions by their IDs,
typically used after backing up the data.
This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
additional tenant validation should be implemented to prevent unauthorized access.
Args:
execution_ids: List of execution IDs to delete
Returns:
The number of executions deleted
"""
...
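
Because this interface is a `typing.Protocol`, implementations conform structurally; they never need to subclass it, and the factory's `_validate_repository_interface` check approximates what a static checker would verify. A toy conforming class (hypothetical, for illustration):

```python
from typing import Optional

from models.workflow import WorkflowNodeExecutionModel


class InMemoryNodeExecutionRepository:
    """Hypothetical stand-in for tests; no explicit inheritance needed."""

    def __init__(self, session_maker) -> None:  # matches the factory's expected signature
        self._store: dict[str, WorkflowNodeExecutionModel] = {}

    def get_execution_by_id(
        self,
        execution_id: str,
        tenant_id: Optional[str] = None,
    ) -> Optional[WorkflowNodeExecutionModel]:
        execution = self._store.get(execution_id)
        if execution and tenant_id and execution.tenant_id != tenant_id:
            return None
        return execution

    # ...the remaining protocol methods would be filled in the same way
```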

View File

@@ -0,0 +1,181 @@
"""
API WorkflowRun Repository Protocol
This module defines the protocol for service-layer WorkflowRun operations.
The repository provides an abstraction layer for WorkflowRun database operations
used by service classes, separating service-layer concerns from core domain logic.
Key Features:
- Paginated workflow run queries with filtering
- Bulk deletion operations with OSS backup support
- Multi-tenant data isolation
- Expired record cleanup with data retention
- Service-layer specific query patterns
Usage:
This protocol should be used by service classes that need to perform
WorkflowRun database operations. It provides a clean interface that
hides implementation details and supports dependency injection.
Example:
```python
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
# Get paginated workflow runs
runs = repo.get_paginated_workflow_runs(
tenant_id="tenant-123",
app_id="app-456",
triggered_from="debugging",
limit=20
)
```
"""
from collections.abc import Sequence
from datetime import datetime
from typing import Optional, Protocol
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.workflow import WorkflowRun
class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
Protocol for service-layer WorkflowRun repository operations.
This protocol defines the interface for WorkflowRun database operations
that are specific to service-layer needs, including pagination, filtering,
and bulk operations with data backup support.
"""
def get_paginated_workflow_runs(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
limit: int = 20,
last_id: Optional[str] = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
Retrieves workflow runs for a specific app and trigger source with
cursor-based pagination support. Used primarily for debugging and
workflow run listing in the UI.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
limit: Maximum number of records to return (default: 20)
last_id: Cursor for pagination - ID of the last record from previous page
Returns:
InfiniteScrollPagination object containing:
- data: List of WorkflowRun objects
- limit: Applied limit
- has_more: Boolean indicating if more records exist
Raises:
ValueError: If last_id is provided but the corresponding record doesn't exist
"""
...
def get_workflow_run_by_id(
self,
tenant_id: str,
app_id: str,
run_id: str,
) -> Optional[WorkflowRun]:
"""
Get a specific workflow run by ID.
Retrieves a single workflow run with tenant and app isolation.
Used for workflow run detail views and execution tracking.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
run_id: Workflow run identifier
Returns:
WorkflowRun object if found, None otherwise
"""
...
def get_expired_runs_batch(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> Sequence[WorkflowRun]:
"""
Get a batch of expired workflow runs for cleanup.
Retrieves workflow runs created before the specified date for
cleanup operations. Used by scheduled tasks to remove old data
while maintaining data retention policies.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
before_date: Only return runs created before this date
batch_size: Maximum number of records to return
Returns:
Sequence of WorkflowRun objects to be processed for cleanup
"""
...
def delete_runs_by_ids(
self,
run_ids: Sequence[str],
) -> int:
"""
Delete workflow runs by their IDs.
Performs bulk deletion of workflow runs by ID. This method should
be used after backing up the data to OSS storage for retention.
Args:
run_ids: Sequence of workflow run IDs to delete
Returns:
Number of records actually deleted
Note:
This method performs hard deletion. Ensure data is backed up
to OSS storage before calling this method for compliance with
data retention policies.
"""
...
def delete_runs_by_app(
self,
tenant_id: str,
app_id: str,
batch_size: int = 1000,
) -> int:
"""
Delete all workflow runs for a specific app.
Performs bulk deletion of all workflow runs associated with an app.
Used during app cleanup operations. Processes records in batches
to avoid memory issues and long-running transactions.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
batch_size: Number of records to process in each batch
Returns:
Total number of records deleted across all batches
Note:
This method performs hard deletion without backup. Use with caution
and ensure proper data retention policies are followed.
"""
...

api/repositories/factory.py (new file, 103 lines)
View File

@@ -0,0 +1,103 @@
"""
DifyAPI Repository Factory for creating repository instances.
This factory is specifically designed for DifyAPI repositories that handle
service-layer operations with dependency injection patterns.
"""
import logging
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
logger = logging.getLogger(__name__)
class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
"""
Factory for creating DifyAPI repository instances based on configuration.
This factory handles the creation of repositories that are specifically designed
for service-layer operations and use dependency injection with sessionmaker
for better testability and separation of concerns.
"""
@classmethod
def create_api_workflow_node_execution_repository(
cls, session_maker: sessionmaker
) -> DifyAPIWorkflowNodeExecutionRepository:
"""
Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.
This repository is designed for service-layer operations and uses dependency injection
with a sessionmaker for better testability and separation of concerns. It provides
database access patterns specifically needed by service classes, handling queries
that involve database-specific fields and multi-tenancy concerns.
Args:
session_maker: SQLAlchemy sessionmaker to inject for database session management.
Returns:
Configured DifyAPIWorkflowNodeExecutionRepository instance
Raises:
RepositoryImportError: If the configured repository cannot be imported or instantiated
"""
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
# Service repository requires session_maker parameter
cls._validate_constructor_signature(repository_class, ["session_maker"])
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
except Exception as e:
logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
raise RepositoryImportError(
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
) from e
@classmethod
def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository:
"""
Create an APIWorkflowRunRepository instance based on configuration.
This repository is designed for service-layer WorkflowRun operations and uses dependency
injection with a sessionmaker for better testability and separation of concerns. It provides
database access patterns specifically needed by service classes for workflow run management,
including pagination, filtering, and bulk operations.
Args:
session_maker: SQLAlchemy sessionmaker to inject for database session management.
Returns:
Configured APIWorkflowRunRepository instance
Raises:
RepositoryImportError: If the configured repository cannot be imported or instantiated
"""
class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}")
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
# Service repository requires session_maker parameter
cls._validate_constructor_signature(repository_class, ["session_maker"])
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
except Exception as e:
logger.exception("Failed to create APIWorkflowRunRepository")
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e

View File

@@ -0,0 +1,290 @@
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
This module provides a concrete implementation of the service repository protocol
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
"""
from collections.abc import Sequence
from datetime import datetime
from typing import Optional
from sqlalchemy import delete, desc, select
from sqlalchemy.orm import Session, sessionmaker
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
This repository provides service-layer database operations for WorkflowNodeExecutionModel
using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
protocol with the following features:
- Multi-tenancy data isolation through tenant_id filtering
- Direct database model operations without domain conversion
- Batch processing for efficient large-scale operations
- Optimized query patterns for common access patterns
- Dependency injection for better testability and maintainability
- Session management and transaction handling with proper cleanup
- Maintenance operations for data lifecycle management
- Thread-safe database operations using session-per-request pattern
"""
def __init__(self, session_maker: sessionmaker[Session]):
"""
Initialize the repository with a sessionmaker.
Args:
session_maker: SQLAlchemy sessionmaker for creating database sessions
"""
self._session_maker = session_maker
def get_node_last_execution(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> Optional[WorkflowNodeExecutionModel]:
"""
Get the most recent execution for a specific node.
This method replicates the query pattern from WorkflowService.get_node_last_run()
using SQLAlchemy 2.0 style syntax.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_id: The workflow identifier
node_id: The node identifier
Returns:
The most recent WorkflowNodeExecutionModel for the node, or None if not found
"""
stmt = (
select(WorkflowNodeExecutionModel)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
WorkflowNodeExecutionModel.node_id == node_id,
)
.order_by(desc(WorkflowNodeExecutionModel.created_at))
.limit(1)
)
with self._session_maker() as session:
return session.scalar(stmt)
def get_executions_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get all node executions for a specific workflow run.
This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
using SQLAlchemy 2.0 style syntax.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_run_id: The workflow run identifier
Returns:
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
"""
stmt = (
select(WorkflowNodeExecutionModel)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(desc(WorkflowNodeExecutionModel.index))
)
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
def get_execution_by_id(
self,
execution_id: str,
tenant_id: Optional[str] = None,
) -> Optional[WorkflowNodeExecutionModel]:
"""
Get a workflow node execution by its ID.
This method replicates the query pattern from WorkflowDraftVariableService
and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.
When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
set `tenant_id` to prevent horizontal privilege escalation.
Args:
execution_id: The execution identifier
tenant_id: Optional tenant identifier for additional filtering
Returns:
The WorkflowNodeExecutionModel if found, or None if not found
"""
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
# Add tenant filtering if provided
if tenant_id is not None:
stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
with self._session_maker() as session:
return session.scalar(stmt)
def delete_expired_executions(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> int:
"""
Delete workflow node executions that are older than the specified date.
Args:
tenant_id: The tenant identifier
before_date: Delete executions created before this date
batch_size: Maximum number of executions to delete in one batch
Returns:
The number of executions deleted
"""
total_deleted = 0
while True:
with self._session_maker() as session:
# Find executions to delete in batches
stmt = (
select(WorkflowNodeExecutionModel.id)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at < before_date,
)
.limit(batch_size)
)
execution_ids = session.execute(stmt).scalars().all()
if not execution_ids:
break
# Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt)
session.commit()
total_deleted += result.rowcount
# If we deleted fewer than the batch size, we're done
if len(execution_ids) < batch_size:
break
return total_deleted
def delete_executions_by_app(
self,
tenant_id: str,
app_id: str,
batch_size: int = 1000,
) -> int:
"""
Delete all workflow node executions for a specific app.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
batch_size: Maximum number of executions to delete in one batch
Returns:
The total number of executions deleted
"""
total_deleted = 0
while True:
with self._session_maker() as session:
# Find executions to delete in batches
stmt = (
select(WorkflowNodeExecutionModel.id)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
)
.limit(batch_size)
)
execution_ids = session.execute(stmt).scalars().all()
if not execution_ids:
break
# Delete the batch
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(delete_stmt)
session.commit()
total_deleted += result.rowcount
# If we deleted fewer than the batch size, we're done
if len(execution_ids) < batch_size:
break
return total_deleted
def get_expired_executions_batch(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get a batch of expired workflow node executions for backup purposes.
Args:
tenant_id: The tenant identifier
before_date: Get executions created before this date
batch_size: Maximum number of executions to retrieve
Returns:
A sequence of WorkflowNodeExecutionModel instances
"""
stmt = (
select(WorkflowNodeExecutionModel)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at < before_date,
)
.limit(batch_size)
)
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
def delete_executions_by_ids(
self,
execution_ids: Sequence[str],
) -> int:
"""
Delete workflow node executions by their IDs.
Args:
execution_ids: List of execution IDs to delete
Returns:
The number of executions deleted
"""
if not execution_ids:
return 0
with self._session_maker() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = session.execute(stmt)
session.commit()
return result.rowcount
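
These batch methods compose into the backup-then-delete loop used by the free-plan cleanup task shown at the end of this commit. A condensed sketch, assuming a `backup` callable that persists a batch to object storage:

```python
from datetime import datetime, timedelta

before_date = datetime.now() - timedelta(days=30)
while True:
    batch = node_execution_repo.get_expired_executions_batch(
        tenant_id=tenant_id,
        before_date=before_date,
        batch_size=1000,
    )
    if not batch:
        break
    backup(batch)  # assumed: writes the rows to object storage first
    node_execution_repo.delete_executions_by_ids([e.id for e in batch])
```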

View File

@@ -0,0 +1,202 @@
"""
SQLAlchemy API WorkflowRun Repository Implementation
This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository
protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0
style queries with proper session management and multi-tenant data isolation.
Key Features:
- SQLAlchemy 2.0 style queries for modern database operations
- Cursor-based pagination for efficient large dataset handling
- Bulk operations with batch processing for performance
- Multi-tenant data isolation and security
- Proper session management with dependency injection
Implementation Notes:
- Uses sessionmaker for consistent session management
- Implements cursor-based pagination using created_at timestamps
- Provides efficient bulk deletion with batch processing
- Maintains data consistency with proper transaction handling
"""
import logging
from collections.abc import Sequence
from datetime import datetime
from typing import Optional, cast
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, sessionmaker
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.workflow import WorkflowRun
logger = logging.getLogger(__name__)
class DifyAPISQLAlchemyWorkflowRunRepository:
"""
SQLAlchemy implementation of APIWorkflowRunRepository.
Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0
style queries. Supports dependency injection through sessionmaker and
maintains proper multi-tenant data isolation.
Args:
session_maker: SQLAlchemy sessionmaker instance for database connections
"""
def __init__(self, session_maker: sessionmaker[Session]) -> None:
"""
Initialize the repository with a sessionmaker.
Args:
session_maker: SQLAlchemy sessionmaker for database connections
"""
self._session_maker = session_maker
def get_paginated_workflow_runs(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
limit: int = 20,
last_id: Optional[str] = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
Implements cursor-based pagination using created_at timestamps for
efficient handling of large datasets. Filters by tenant, app, and
trigger source for proper data isolation.
"""
with self._session_maker() as session:
# Build base query with filters
base_stmt = select(WorkflowRun).where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.app_id == app_id,
WorkflowRun.triggered_from == triggered_from,
)
if last_id:
# Get the last workflow run for cursor-based pagination
last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
last_workflow_run = session.scalar(last_run_stmt)
if not last_workflow_run:
raise ValueError("Last workflow run does not exist")
# Get records created before the last run's timestamp
base_stmt = base_stmt.where(
WorkflowRun.created_at < last_workflow_run.created_at,
WorkflowRun.id != last_workflow_run.id,
)
# Fetch the page ordered by recency, requesting one extra row to detect more pages
workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
# Check if there are more records for pagination
has_more = len(workflow_runs) > limit
if has_more:
workflow_runs = workflow_runs[:-1]
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_workflow_run_by_id(
self,
tenant_id: str,
app_id: str,
run_id: str,
) -> Optional[WorkflowRun]:
"""
Get a specific workflow run by ID with tenant and app isolation.
"""
with self._session_maker() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.app_id == app_id,
WorkflowRun.id == run_id,
)
return cast(Optional[WorkflowRun], session.scalar(stmt))
def get_expired_runs_batch(
self,
tenant_id: str,
before_date: datetime,
batch_size: int = 1000,
) -> Sequence[WorkflowRun]:
"""
Get a batch of expired workflow runs for cleanup operations.
"""
with self._session_maker() as session:
stmt = (
select(WorkflowRun)
.where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.created_at < before_date,
)
.limit(batch_size)
)
return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
def delete_runs_by_ids(
self,
run_ids: Sequence[str],
) -> int:
"""
Delete workflow runs by their IDs using bulk deletion.
"""
if not run_ids:
return 0
with self._session_maker() as session:
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(stmt)
session.commit()
deleted_count = cast(int, result.rowcount)
logger.info(f"Deleted {deleted_count} workflow runs by IDs")
return deleted_count
def delete_runs_by_app(
self,
tenant_id: str,
app_id: str,
batch_size: int = 1000,
) -> int:
"""
Delete all workflow runs for a specific app in batches.
"""
total_deleted = 0
while True:
with self._session_maker() as session:
# Get a batch of run IDs to delete
stmt = (
select(WorkflowRun.id)
.where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.app_id == app_id,
)
.limit(batch_size)
)
run_ids = session.scalars(stmt).all()
if not run_ids:
break
# Delete the batch
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
result = session.execute(delete_stmt)
session.commit()
batch_deleted = result.rowcount
total_deleted += batch_deleted
logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}")
# If we deleted fewer records than the batch size, we're done
if batch_deleted < batch_size:
break
logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}")
return total_deleted
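
For reference, a minimal usage sketch of the repository above, assuming a Dify API context where db.engine is the application's SQLAlchemy Engine; the tenant and app IDs are placeholders:

from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db

session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPISQLAlchemyWorkflowRunRepository(session_maker)

# First page: newest 20 debugging runs for the app ("debugging" is
# WorkflowRunTriggeredFrom.DEBUGGING.value).
page = repo.get_paginated_workflow_runs(
    tenant_id="tenant-123",  # placeholder IDs
    app_id="app-456",
    triggered_from="debugging",
    limit=20,
)
# Next page: pass the last run's id as the cursor.
if page.has_more:
    next_page = repo.get_paginated_workflow_runs(
        tenant_id="tenant-123",
        app_id="app-456",
        triggered_from="debugging",
        limit=20,
        last_id=page.data[-1].id,
    )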

View File

@@ -47,8 +47,6 @@ class AppService:
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)

View File

@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@@ -105,23 +105,24 @@ class ClearFreePlanTenantExpiredLogs:
)
)
# Process expired workflow node executions with backup
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
before_date = datetime.datetime.now() - datetime.timedelta(days=days)
total_deleted = 0
while True:
with Session(db.engine).no_autoflush as session:
workflow_node_executions = (
session.query(WorkflowNodeExecutionModel)
.filter(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at
< datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
# Get a batch of expired executions for backup
workflow_node_executions = node_execution_repo.get_expired_executions_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=batch,
)
if len(workflow_node_executions) == 0:
break
# save workflow node executions
# Save workflow node executions to storage
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
@@ -131,15 +132,14 @@ class ClearFreePlanTenantExpiredLogs:
).encode("utf-8"),
)
# Extract IDs for deletion
workflow_node_execution_ids = [
workflow_node_execution.id for workflow_node_execution in workflow_node_executions
]
# delete workflow node executions
session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
).delete(synchronize_session=False)
session.commit()
# Delete the backed up executions
deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
total_deleted += deleted_count
click.echo(
click.style(
@@ -148,23 +148,28 @@ class ClearFreePlanTenantExpiredLogs:
)
)
# If we got fewer than the batch size, we're done
if len(workflow_node_executions) < batch:
break
# Process expired workflow runs with backup
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
before_date = datetime.datetime.now() - datetime.timedelta(days=days)
total_deleted = 0
while True:
with Session(db.engine).no_autoflush as session:
workflow_runs = (
session.query(WorkflowRun)
.filter(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
# Get a batch of expired workflow runs for backup
workflow_runs = workflow_run_repo.get_expired_runs_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=batch,
)
if len(workflow_runs) == 0:
break
# save workflow runs
# Save workflow runs to storage
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
@@ -176,13 +181,23 @@ class ClearFreePlanTenantExpiredLogs:
).encode("utf-8"),
)
# Extract IDs for deletion
workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
# delete workflow runs
session.query(WorkflowRun).filter(
WorkflowRun.id.in_(workflow_run_ids),
).delete(synchronize_session=False)
session.commit()
# Delete the backed up workflow runs
deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
total_deleted += deleted_count
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
f" workflow runs for tenant {tenant_id}"
)
)
# If we got fewer than the batch size, we're done
if len(workflow_runs) < batch:
break
@classmethod
def process(cls, days: int, batch: int, tenant_ids: list[str]):
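
Both cleanup passes above follow the same backup-then-delete shape; a condensed sketch of that pattern, where serialize_batch and backup_key are hypothetical stand-ins rather than helpers from the codebase:

# Condensed sketch of the batch cleanup loop above; serialize_batch and
# backup_key are hypothetical stand-ins, not functions in the codebase.
while True:
    records = workflow_run_repo.get_expired_runs_batch(
        tenant_id=tenant_id,
        before_date=before_date,
        batch_size=batch,
    )
    if not records:
        break
    storage.save(backup_key(tenant_id), serialize_batch(records))  # back up first
    workflow_run_repo.delete_runs_by_ids([r.id for r in records])  # then delete
    if len(records) < batch:  # a short batch means nothing left to process
        break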

View File

@@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar
from sqlalchemy import Engine, orm, select
from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
_logger = logging.getLogger(__name__)
@@ -117,7 +118,24 @@ class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
Args:
session (Session): The SQLAlchemy session used to execute database queries.
The provided session must be bound to an `Engine` object, not a specific `Connection`.
Raises:
AssertionError: If the provided session is not bound to an `Engine` object.
"""
self._session = session
engine = session.get_bind()
# Ensure the session is bound to an Engine, not a Connection.
assert isinstance(engine, Engine)
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
@@ -248,8 +266,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
node_exec = self._session.scalars(query).first()
node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
@@ -298,6 +315,8 @@ class WorkflowDraftVariableService:
def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
variable_type = variable.get_variable_type()
if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name):
raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
if variable_type == DraftVariableType.CONVERSATION:
return self._reset_conv_var(workflow, variable)
else:

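A minimal construction sketch for the service above, assuming db.engine is the application's Engine (the assertion in __init__ rejects sessions bound to a Connection); the variable ID is a placeholder:

from sqlalchemy.orm import Session
from extensions.ext_database import db

with Session(bind=db.engine) as session:  # must be bound to an Engine
    service = WorkflowDraftVariableService(session)
    variable = service.get_variable("variable-id")  # placeholder ID
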
View File

@@ -2,9 +2,9 @@ import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy.orm import sessionmaker
import contexts
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import (
@@ -15,10 +15,18 @@ from models import (
WorkflowRun,
WorkflowRunTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunService:
def __init__(self):
"""Initialize WorkflowRunService with repository dependencies."""
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
@@ -62,45 +70,16 @@ class WorkflowRunService:
:param args: request args
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == app_model.tenant_id,
WorkflowRun.app_id == app_model.id,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
limit=limit,
last_id=last_id,
)
if args.get("last_id"):
last_workflow_run = base_query.filter(
WorkflowRun.id == args.get("last_id"),
).first()
if not last_workflow_run:
raise ValueError("Last workflow run not exists")
workflow_runs = (
base_query.filter(
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
)
.order_by(WorkflowRun.created_at.desc())
.limit(limit)
.all()
)
else:
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
has_more = False
if len(workflow_runs) == limit:
current_page_first_workflow_run = workflow_runs[-1]
rest_count = base_query.filter(
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
WorkflowRun.id != current_page_first_workflow_run.id,
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
@@ -108,17 +87,11 @@ class WorkflowRunService:
:param app_model: app model
:param run_id: workflow run id
"""
workflow_run = (
db.session.query(WorkflowRun)
.filter(
WorkflowRun.tenant_id == app_model.tenant_id,
WorkflowRun.app_id == app_model.id,
WorkflowRun.id == run_id,
return self._workflow_run_repo.get_workflow_run_by_id(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
run_id=run_id,
)
.first()
)
return workflow_run
def get_workflow_run_node_executions(
self,
@@ -137,17 +110,13 @@ class WorkflowRunService:
if not workflow_run:
return []
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
user=user,
# Get tenant_id from user
tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if tenant_id is None:
raise ValueError("User tenant_id cannot be None")
return self._node_execution_service_repo.get_executions_by_workflow_run(
tenant_id=tenant_id,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=run_id,
)
# Use the repository to get the database models directly
order_config = OrderConfig(order_by=["index"], order_direction="desc")
workflow_node_executions = repository.get_db_models_by_workflow_run(
workflow_run_id=run_id, order_config=order_config
)
return workflow_node_executions

View File

@@ -7,13 +7,13 @@ from typing import Any, Optional
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -41,6 +41,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
@@ -57,21 +58,32 @@ class WorkflowService:
Workflow Service
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
# TODO(QuantumGhost): This query is not fully covered by index.
criteria = (
WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id,
WorkflowNodeExecutionModel.app_id == app_model.id,
WorkflowNodeExecutionModel.workflow_id == workflow.id,
WorkflowNodeExecutionModel.node_id == node_id,
"""
Get the most recent execution for a specific node.
Args:
app_model: The application model
workflow: The workflow model
node_id: The node identifier
Returns:
The most recent WorkflowNodeExecutionModel for the node, or None if not found
"""
return self._node_execution_service_repo.get_node_last_execution(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=workflow.id,
node_id=node_id,
)
node_exec = (
db.session.query(WorkflowNodeExecutionModel)
.filter(*criteria)
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.first()
)
return node_exec
def is_workflow_exist(self, app_model: App) -> bool:
return (
@@ -396,7 +408,7 @@ class WorkflowService:
node_execution.workflow_id = draft_workflow.id
# Create repository and save the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=db.engine,
user=account,
app_id=app_model.id,
@@ -404,8 +416,9 @@ class WorkflowService:
)
repository.save(node_execution)
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution = repository.to_db_model(node_execution)
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
@@ -418,6 +431,7 @@ class WorkflowService:
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
session.commit()
return workflow_node_execution
def run_free_workflow_node(
@@ -429,7 +443,7 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
@@ -441,7 +455,7 @@ class WorkflowService:
node_id=node_id,
)
return workflow_node_execution
return node_execution
def _handle_node_run_result(
self,

View File

@@ -6,6 +6,7 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models import (
@@ -31,7 +32,8 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -189,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def del_workflow_run(workflow_run_id: str):
db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
"""Delete all workflow runs for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
_delete_records(
"""select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_run,
"workflow run",
deleted_count = workflow_run_repo.delete_runs_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}")
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str):
db.session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id == workflow_node_execution_id
).delete(synchronize_session=False)
"""Delete all workflow node executions for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
_delete_records(
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_node_execution,
"workflow node execution",
deleted_count = node_execution_repo.delete_executions_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):

View File

@@ -0,0 +1 @@
# Unit tests for core repositories module

View File

@@ -0,0 +1,455 @@
"""
Unit tests for the RepositoryFactory.
This module tests the factory pattern implementation for creating repository instances
based on configuration, including error handling and validation.
"""
from unittest.mock import MagicMock, patch
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
class TestRepositoryFactory:
"""Test cases for RepositoryFactory."""
def test_import_class_success(self):
"""Test successful class import."""
# Test importing a real class
class_path = "unittest.mock.MagicMock"
result = DifyCoreRepositoryFactory._import_class(class_path)
assert result is MagicMock
def test_import_class_invalid_path(self):
"""Test import with invalid module path."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalid.module.path")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_invalid_class_name(self):
"""Test import with invalid class name."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_malformed_path(self):
"""Test import with malformed path (no dots)."""
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalidpath")
assert "Cannot import repository class" in str(exc_info.value)
def test_validate_repository_interface_success(self):
"""Test successful interface validation."""
# Create a mock class that implements the required methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
# Create a mock interface with the same methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that doesn't implement all required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface with required methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
def test_validate_constructor_signature_success(self):
"""Test successful constructor signature validation."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_missing_params(self):
"""Test constructor validation with missing parameters."""
class IncompleteRepository:
def __init__(self, session_factory, user):
# Missing app_id and triggered_from parameters
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
"""Test constructor validation when inspection fails."""
# Mock inspect.signature to raise an exception
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
class MockRepository:
def __init__(self, session_factory):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowExecutionRepository."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Cannot import repository class" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowNodeExecutionRepository."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Cannot import repository class" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception."""
error_message = "Test error message"
exception = RepositoryImportError(error_message)
assert str(exception) == error_message
assert isinstance(exception, Exception)
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies with Engine instead of sessionmaker
mock_engine = MagicMock(spec=Engine)
mock_user = MagicMock(spec=Account)
# Mock the imported class to be a valid repository
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
# Mock the validation methods
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Verify the repository was created with the Engine
mock_repository_class.assert_called_once_with(
session_factory=mock_engine,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test interface validation ignores private methods."""
# Create a mock class with private methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
pass
# Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_with_kwargs(self):
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
class MockRepository:
def __init__(self, session_factory, user, **kwargs):
pass
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)

View File

@@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture
def workflow_setup():
workflow_service = WorkflowService()
mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"

View File

@@ -1,14 +1,14 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import (
)
@pytest.fixture
def mock_engine() -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def mock_session(mock_engine) -> Session:
mock_session = Mock(spec=Session)
mock_session.get_bind.return_value = mock_engine
return mock_session
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = mock.MagicMock(spec=Session)
mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -70,7 +82,7 @@ class TestDraftVariableSaver:
),
]
mock_session = mock.MagicMock(spec=Session)
mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService:
conversation_variables=[],
)
def test_reset_conversation_variable(self):
def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService:
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self):
def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService:
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(self):
def test_reset_node_variable_with_missing_execution_record(
self,
mock_engine,
mock_session,
monkeypatch,
):
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
# Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(self):
def test_reset_node_variable_with_valid_execution_record(
self,
mock_session,
monkeypatch,
):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService:
# Should return the updated variable
assert result == variable
def test_reset_non_editable_system_variable_raises_error(self):
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService:
editable=False, # Non-editable system variable
)
# Mock the service to properly check system variable editability
with patch.object(service, "reset_variable") as mock_reset:
def side_effect(wf, var):
if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
return var
mock_reset.side_effect = side_effect
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def test_reset_editable_system_variable_succeeds(self):
def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
@@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService:
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_reset_query_system_variable_succeeds(self):
def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)

View File

@@ -0,0 +1,288 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
@pytest.fixture
def repository(self):
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
assert hasattr(repository, "get_node_last_execution")
assert hasattr(repository, "get_executions_by_workflow_run")
assert hasattr(repository, "get_execution_by_id")
# Verify methods are callable
assert callable(repository.get_node_last_execution)
assert callable(repository.get_executions_by_workflow_run)
assert callable(repository.get_execution_by_id)
assert callable(repository.delete_expired_executions)
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()
def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()
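
The last two tests imply an early return on an empty ID list before any session work happens; a sketch consistent with that tested behavior, mirroring delete_runs_by_ids from the WorkflowRun repository above:

# Sketch of a repository method consistent with the tests above; not
# necessarily the shipped code.
from sqlalchemy import delete
from models.workflow import WorkflowNodeExecutionModel

def delete_executions_by_ids(self, execution_ids):
    if not execution_ids:  # empty list: no session, no commit
        return 0
    with self._session_maker() as session:
        result = session.execute(
            delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
        )
        session.commit()
        return result.rowcount  # number of rows removed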

View File

@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):

View File

@@ -799,6 +799,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
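
Any importable class that passes the factory's interface and constructor checks can be plugged in through these variables; a skeleton of such a drop-in, where my_pkg.repositories is a hypothetical module path:

# Hypothetical drop-in; enable it with
# CORE_WORKFLOW_EXECUTION_REPOSITORY=my_pkg.repositories.CustomWorkflowExecutionRepository
class CustomWorkflowExecutionRepository:
    def __init__(self, session_factory, user, app_id, triggered_from):
        self._session_factory = session_factory
        self._user = user
        self._app_id = app_id
        self._triggered_from = triggered_from

    def save(self, execution):
        ...  # implement the rest of the WorkflowExecutionRepository interface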

View File

@@ -354,6 +354,10 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}