feat: Add an asynchronous repository to improve workflow performance (#20050)

Co-authored-by: liangxin <liangxin@shein.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: liangxin <xinlmain@gmail.com>
Author: xinlmain
Committed: 2025-08-13 02:28:06 +08:00 (via GitHub)
Parent: 6e6389c930
Commit: b3399642c5
17 changed files with 1376 additions and 217 deletions

View File

@@ -5,7 +5,7 @@ cd web && pnpm install
pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc

View File

@@ -74,7 +74,7 @@
10. If you need to handle and debug the async tasks (e.g., dataset importing and document indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
```
In addition, if you want to debug the Celery scheduled tasks, you can use the following command in another terminal:

View File

@@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
description="Repository implementation for WorkflowExecution. Options: "
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
description="Repository implementation for WorkflowNodeExecution. Options: "
"'core.repositories.sqlalchemy_workflow_node_execution_repository."
"SQLAlchemyWorkflowNodeExecutionRepository' (default), "
"'core.repositories.celery_workflow_node_execution_repository."
"CeleryWorkflowNodeExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)
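These settings are resolved at runtime by importing the configured module path. The factory's `_import_class` helper presumably amounts to something like this sketch (the function name here is illustrative):

```python
import importlib


def resolve_repository_class(class_path: str) -> type:
    """Split 'pkg.module.ClassName' into module path and class name, then import it."""
    module_path, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Example with the new Celery-backed implementation:
# resolve_repository_class(
#     "core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository"
# )
```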

View File

@@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",

View File

@@ -0,0 +1,126 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
save_workflow_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
Celery-based implementation of the WorkflowExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowRunTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
logger.info(
"Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance asynchronously using Celery.
This method queues the save operation as a Celery task and returns immediately,
providing improved performance for high-throughput scenarios.
Args:
execution: The WorkflowExecution instance to save or update
"""
try:
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Queued async save for workflow execution: %s", execution.id_)
except Exception:
logger.exception("Failed to queue save operation for execution %s", execution.id_)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
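A minimal usage sketch for the new repository, assuming a Celery worker is consuming the `workflow_storage` queue; the engine URL, `current_account`, and `execution` are placeholders:

```python
from sqlalchemy import create_engine

from core.repositories import CeleryWorkflowExecutionRepository
from models.enums import WorkflowRunTriggeredFrom

engine = create_engine("postgresql://user:pass@localhost/dify")  # placeholder URL

repo = CeleryWorkflowExecutionRepository(
    session_factory=engine,  # an Engine is wrapped into a sessionmaker internally
    user=current_account,    # Account or EndUser carrying a tenant id (assumed in scope)
    app_id="my-app-id",
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)

repo.save(execution)  # returns immediately; the DB write happens in a background worker
```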

View File

@@ -0,0 +1,190 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow node execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- In-memory cache for immediate reads
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
_execution_cache: dict[str, WorkflowNodeExecution]
_workflow_execution_mapping: dict[str, list[str]]
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# In-memory cache for workflow node executions
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
self._workflow_execution_mapping: dict[str, list[str]] = {}
logger.info(
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.
This method stores the execution in cache immediately for fast reads and queues
the save operation as a Celery task without tracking the task status.
Args:
execution: The WorkflowNodeExecution instance to save or update
"""
try:
# Store in cache immediately for fast reads
self._execution_cache[execution.id] = execution
# Update workflow execution mapping for efficient retrieval
if execution.workflow_execution_id:
if execution.workflow_execution_id not in self._workflow_execution_mapping:
self._workflow_execution_mapping[execution.workflow_execution_id] = []
if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_node_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
except Exception:
logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
Returns:
A sequence of WorkflowNodeExecution instances
"""
try:
# Get execution IDs for this workflow run from cache
execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
# Retrieve executions from cache
result = []
for execution_id in execution_ids:
if execution_id in self._execution_cache:
result.append(self._execution_cache[execution_id])
# Apply ordering if specified
if order_config and result:
# Sort based on the order configuration
reverse = order_config.order_direction == "desc"
# Sort by multiple fields if specified; Python's sort is stable, so sorting
# by each key from least to most significant yields the combined ordering
for field_name in reversed(order_config.order_by):
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
return result
except Exception:
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
return []
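Because `save()` also populates the in-memory cache, reads within the same repository instance never touch the database. A sketch, assuming a repository instance `repo` and a `WorkflowNodeExecution` named `node_execution`:

```python
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig

repo.save(node_execution)  # cached locally, persisted asynchronously via Celery

# Reads are served from the per-instance cache, here ordered by node index
ordered = repo.get_by_workflow_run(
    workflow_run_id=node_execution.workflow_execution_id,
    order_config=OrderConfig(order_by=["index"], order_direction="asc"),
)
```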

View File

@@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
"""
Validate that a repository class constructor accepts required parameters.
Args:
repository_class: The class to validate
required_params: List of required parameter names
Raises:
RepositoryImportError: If the constructor doesn't accept required parameters
"""
@@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
@@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
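With both implementations now sharing the same constructor, the factory only checks that the configured class implements the repository interface before instantiating it. That check presumably verifies that every public interface method exists on the class; a sketch of the idea, not the factory's exact code:

```python
def validate_interface(repository_class: type, interface: type) -> None:
    # Gather the interface's public callables and confirm the class provides them all
    required = [
        name
        for name in dir(interface)
        if not name.startswith("_") and callable(getattr(interface, name))
    ]
    missing = [name for name in required if not callable(getattr(repository_class, name, None))]
    if missing:
        raise TypeError(f"{repository_class.__name__} does not implement required methods: {missing}")
```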

View File

@@ -1,4 +1,5 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any
from pydantic import BaseModel
@@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Decimal):
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
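The new `Decimal` branch matters because `json.dumps` rejects `Decimal` values by default; converting to `float` makes workflow outputs JSON-encodable, at the cost of float precision. A quick illustration:

```python
import json
from decimal import Decimal

price = Decimal("19.99")
# json.dumps(price) raises TypeError: Object of type Decimal is not JSON serializable
print(json.dumps(float(price)))  # 19.99
```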

View File

@@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
# Retry with exponential backoff: roughly 60s, 120s, then 240s before max_retries=3 gives up
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at
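For local testing, the task can be run eagerly with Celery's `Task.apply()`, which executes the body in-process without a broker; the kwarg values below are illustrative:

```python
result = save_workflow_execution_task.apply(
    kwargs={
        "execution_data": execution.model_dump(),  # `execution` assumed in scope
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "triggered_from": "app-run",  # enum string value is illustrative
        "creator_user_id": "user-id",
        "creator_user_role": "account",
    }
)
assert result.get() is True
```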

View File

@@ -0,0 +1,171 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry with exponential backoff: roughly 60s, 120s, then 240s before max_retries=3 gives up
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at
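The enum-key conversion is needed because `json.dumps` only accepts str, int, float, bool, or None as object keys, while the metadata dict is keyed by enum members. A reduced illustration of the same pattern (the enum here is a hypothetical stand-in):

```python
import json
from enum import Enum


class MetadataKey(Enum):  # hypothetical stand-in for the real metadata enum
    TOTAL_TOKENS = "total_tokens"


metadata = {MetadataKey.TOTAL_TOKENS: 42}
safe = {k.value if hasattr(k, "value") else str(k): v for k, v in metadata.items()}
print(json.dumps(safe))  # {"total_tokens": 42}
```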

View File

@@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowExecutionRepository:
"""Test cases for CeleryWorkflowExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_basic_functionality(self, mock_session_factory, mock_account):
"""Test repository initialization basic functionality."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Verify basic initialization
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == "test-app"
assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
"""Test that save operation queues a Celery task without tracking."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify no task tracking occurs (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_execution)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_operation_fire_and_forget(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation works in fire-and-forget mode."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Test that save doesn't block or maintain state
repo.save(sample_workflow_execution)
# Verify no pending saves are tracked (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
"""Test multiple save operations work correctly."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Create multiple executions
exec1 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input2": "value2"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Should work without issues and not maintain state (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
"""Test save operation with different user types."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
repo.save(execution)
# Verify task was called with EndUser context
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["tenant_id"] == mock_end_user.tenant_id
assert call_args["creator_user_id"] == mock_end_user.id

View File

@@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_node_execution():
"""Sample WorkflowNodeExecution for testing."""
return WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
node_id="test_node",
node_type=NodeType.START,
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowNodeExecutionRepository:
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
"""Test repository initialization with cache properly initialized."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert repo._execution_cache == {}
assert repo._workflow_execution_mapping == {}
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_caches_and_queues_celery_task(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation caches execution and queues a Celery task."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.save(sample_workflow_node_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify execution is cached
assert sample_workflow_node_execution.id in repo._execution_cache
assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
# Verify workflow execution mapping is updated
assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
assert (
sample_workflow_node_execution.id
in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_node_execution)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_get_by_workflow_run_from_cache(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that get_by_workflow_run retrieves executions from cache."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Save execution to cache first
repo.save(sample_workflow_node_execution)
workflow_run_id = sample_workflow_node_execution.workflow_execution_id
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
# Verify results were retrieved from cache
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
assert result[0] is sample_workflow_node_execution
def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
"""Test get_by_workflow_run without order configuration."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
result = repo.get_by_workflow_run("workflow-run-id")
# Should return empty list since nothing in cache
assert len(result) == 0
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
"""Test cache operations work correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Test saving to cache
repo.save(sample_workflow_node_execution)
# Verify cache contains the execution
assert sample_workflow_node_execution.id in repo._execution_cache
# Test retrieving from cache
result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
"""Test multiple executions for the same workflow."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create multiple executions for the same workflow
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.START,
title="Node 1",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.LLM,
title="Node 2",
inputs={"input2": "value2"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Verify both are cached and mapped
assert len(repo._execution_cache) == 2
assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2
# Test retrieval
result = repo.get_by_workflow_run(workflow_run_id)
assert len(result) == 2
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
"""Test ordering functionality works correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create executions with different indices
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.START,
title="Node 2",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.LLM,
title="Node 1",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save out of index order (index 2 first, then index 1)
repo.save(exec1)
repo.save(exec2)
# Test ascending order
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 1
assert result[1].index == 2
# Test descending order
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 2
assert result[1].index == 1

View File

@@ -59,7 +59,7 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Create a mock interface with the same methods
# Create a mock interface class
class MockInterface:
def save(self):
pass
@@ -67,20 +67,20 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Should not raise an exception
# Should not raise an exception when all methods are present
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that doesn't implement all required methods
# Create a mock class that's missing required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface with required methods
# Create a mock interface that requires both methods
class MockInterface:
def save(self):
pass
@@ -88,57 +88,39 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
def missing_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
def test_validate_constructor_signature_success(self):
"""Test successful constructor signature validation."""
def test_validate_repository_interface_with_private_methods(self):
"""Test that private methods are ignored during interface validation."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from):
def save(self):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_missing_params(self):
"""Test constructor validation with missing parameters."""
class IncompleteRepository:
def __init__(self, session_factory, user):
# Missing app_id and triggered_from parameters
def _private_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
"""Test constructor validation when inspection fails."""
# Mock inspect.signature to raise an exception
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
class MockRepository:
def __init__(self, session_factory):
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
def _private_method(self):
pass
# Should not raise exception - private methods should be ignored
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowExecutionRepository."""
def test_create_workflow_execution_repository_success(self, mock_config):
"""Test successful WorkflowExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
@@ -146,7 +128,7 @@ class TestRepositoryFactory:
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@@ -155,7 +137,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@@ -177,7 +158,7 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
@@ -195,45 +176,46 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import to succeed but validation to fail
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
@@ -245,18 +227,18 @@ class TestRepositoryFactory:
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowNodeExecutionRepository."""
def test_create_workflow_node_execution_repository_success(self, mock_config):
"""Test successful WorkflowNodeExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@@ -265,7 +247,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@@ -287,7 +268,7 @@ class TestRepositoryFactory:
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
@@ -297,28 +278,83 @@ class TestRepositoryFactory:
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Cannot import repository class" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception."""
error_message = "Test error message"
exception = RepositoryImportError(error_message)
assert str(exception) == error_message
assert isinstance(exception, Exception)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception handling."""
error_message = "Custom error message"
error = RepositoryImportError(error_message)
assert str(error) == error_message
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies with Engine instead of sessionmaker
# Create mock dependencies using Engine instead of sessionmaker
mock_engine = MagicMock(spec=Engine)
mock_user = MagicMock(spec=Account)
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@@ -327,129 +363,19 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with the Engine
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_engine,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
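The test above passes an `Engine` where a `sessionmaker` is normally given. A minimal sketch of why that can work, assuming the repository normalizes either input (a hypothetical helper, not the shipped implementation):
```python
# Sketch only: accept either an Engine or a sessionmaker and normalize
# to a sessionmaker (assumed pattern; the actual repositories may differ).
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import sessionmaker

def to_session_factory(session_factory: sessionmaker | Engine) -> sessionmaker:
    if isinstance(session_factory, Engine):
        # Wrap a bare Engine so callers can pass either form.
        return sessionmaker(bind=session_factory)
    return session_factory

# Usage: both an Engine and an existing sessionmaker yield a usable factory.
factory = to_session_factory(create_engine("sqlite://"))
```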
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test interface validation ignores private methods."""
# Create a mock class with private methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
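For context, a minimal sketch of the public-name comparison this test exercises, assuming `_validate_repository_interface` works roughly like this (illustrative shape, not the actual implementation):
```python
from core.repositories.factory import RepositoryImportError

def validate_repository_interface(repository_class: type, interface: type) -> None:
    # Only names that do not start with an underscore are compared,
    # which is why _private_method is ignored in the test above.
    required = {name for name in dir(interface) if not name.startswith("_")}
    provided = {name for name in dir(repository_class) if not name.startswith("_")}
    missing = required - provided
    if missing:
        raise RepositoryImportError(
            f"{repository_class.__name__} is missing methods: {sorted(missing)}"
        )
```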
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
pass
# Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_with_kwargs(self):
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
class MockRepository:
def __init__(self, session_factory, user, **kwargs):
pass
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)

View File

@@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
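The worker now also consumes the plugin and workflow_storage queues. As a hedged sketch (the task name and payload are hypothetical), work can be routed to the new queue at call time with Celery's standard `apply_async(queue=...)`:
```python
from celery import shared_task

@shared_task
def save_workflow_execution(payload: dict) -> None:
    # Persist the serialized execution; the body is illustrative only.
    ...

# The queue name must match one of the entries in the worker's -Q list above.
save_workflow_execution.apply_async(args=[{"run_id": "example"}], queue="workflow_storage")
```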

View File

@@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
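These repository settings hold dotted module paths. A minimal sketch of how such a path can be resolved into a class, in the spirit of the `_import_class` helper the factory tests patch (assumed shape, not the exact implementation):
```python
import importlib

def import_class(dotted_path: str) -> type:
    # "pkg.module.ClassName" -> import pkg.module, then fetch ClassName.
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Example with the default value of CORE_WORKFLOW_EXECUTION_REPOSITORY above:
repo_cls = import_class(
    "core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository"
)
```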

View File

@@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}