feat: Add an asynchronous repository to improve workflow performance (#20050)
Co-authored-by: liangxin <liangxin@shein.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: liangxin <xinlmain@gmail.com>
@@ -5,7 +5,7 @@ cd web && pnpm install
 pipx install uv

 echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
 echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
@@ -74,7 +74,7 @@
 10. If you need to handle and debug async tasks (e.g. dataset importing and document indexing), please start the worker service.

    ```bash
-   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
+   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
    ```

 In addition, if you want to debug the Celery scheduled tasks, you can use the following command in another terminal:
@@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
     """

     CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowExecution. Specify as a module path",
+        description="Repository implementation for WorkflowExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
         default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
     )

     CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+        description="Repository implementation for WorkflowNodeExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_node_execution_repository."
+        "SQLAlchemyWorkflowNodeExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_node_execution_repository."
+        "CeleryWorkflowNodeExecutionRepository'",
         default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
     )
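For reference, a minimal sketch of how a deployment might opt into the Celery-backed implementations through these settings. Only the variable names and class paths come from `RepositoryConfig` above; the surrounding startup code is illustrative:

```python
# Hypothetical sketch: selecting the Celery-backed repositories via environment
# variables before Dify's settings are loaded. Variable names and class paths
# are taken from RepositoryConfig; everything else is an assumption.
import os

os.environ["CORE_WORKFLOW_EXECUTION_REPOSITORY"] = (
    "core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository"
)
os.environ["CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY"] = (
    "core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository"
)
```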
@@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
 defined in the core.workflow.repository package.
 """

+from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
+from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
 from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository

 __all__ = [
+    "CeleryWorkflowExecutionRepository",
+    "CeleryWorkflowNodeExecutionRepository",
     "DifyCoreRepositoryFactory",
     "RepositoryImportError",
     "SQLAlchemyWorkflowNodeExecutionRepository",
api/core/repositories/celery_workflow_execution_repository.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""

import logging
from typing import Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
    save_workflow_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
    """
    Celery-based implementation of the WorkflowExecutionRepository interface.

    This implementation provides asynchronous storage capabilities by using Celery tasks
    to handle database operations in background workers. This improves performance by
    reducing the blocking time for workflow execution storage operations.

    Key features:
    - Asynchronous save operations using Celery tasks
    - Support for multi-tenancy through tenant/app filtering
    - Automatic retry and error handling through Celery
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: Optional[str]
    _triggered_from: Optional[WorkflowRunTriggeredFrom]
    _creator_user_id: str
    _creator_user_role: CreatorUserRole

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom],
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        logger.info(
            "Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowExecution) -> None:
        """
        Save or update a WorkflowExecution instance asynchronously using Celery.

        This method queues the save operation as a Celery task and returns immediately,
        providing improved performance for high-throughput scenarios.

        Args:
            execution: The WorkflowExecution instance to save or update
        """
        try:
            # Serialize execution for Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task (fire and forget)
            save_workflow_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            logger.debug("Queued async save for workflow execution: %s", execution.id_)

        except Exception:
            logger.exception("Failed to queue save operation for execution %s", execution.id_)
            # In case of Celery failure, we could implement a fallback to synchronous save
            # For now, we'll re-raise the exception
            raise
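A hedged usage sketch, not part of the diff: constructing the repository by hand and saving an execution. In Dify this wiring normally goes through `DifyCoreRepositoryFactory` using the config options above; the function wrapper and app ID below are illustrative.

```python
# Illustrative only: direct construction of the Celery-backed repository.
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom

def save_async(user, execution):
    """`user` is an Account/EndUser and `execution` a WorkflowExecution (assumed inputs)."""
    repo = CeleryWorkflowExecutionRepository(
        session_factory=db.engine,  # an Engine is wrapped into a sessionmaker internally
        user=user,
        app_id="example-app-id",  # illustrative app ID
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    repo.save(execution)  # returns at once; the write happens on the workflow_storage queue
```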
api/core/repositories/celery_workflow_node_execution_repository.py (new file, 190 lines)
@@ -0,0 +1,190 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""

import logging
from collections.abc import Sequence
from typing import Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
    OrderConfig,
    WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
    save_workflow_node_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    """
    Celery-based implementation of the WorkflowNodeExecutionRepository interface.

    This implementation provides asynchronous storage capabilities by using Celery tasks
    to handle database operations in background workers. This improves performance by
    reducing the blocking time for workflow node execution storage operations.

    Key features:
    - Asynchronous save operations using Celery tasks
    - In-memory cache for immediate reads
    - Support for multi-tenancy through tenant/app filtering
    - Automatic retry and error handling through Celery
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: Optional[str]
    _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
    _creator_user_id: str
    _creator_user_role: CreatorUserRole
    _execution_cache: dict[str, WorkflowNodeExecution]
    _workflow_execution_mapping: dict[str, list[str]]

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        # In-memory cache for workflow node executions
        self._execution_cache: dict[str, WorkflowNodeExecution] = {}

        # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
        self._workflow_execution_mapping: dict[str, list[str]] = {}

        logger.info(
            "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowNodeExecution) -> None:
        """
        Save or update a WorkflowNodeExecution instance to cache and asynchronously to the database.

        This method stores the execution in cache immediately for fast reads and queues
        the save operation as a Celery task without tracking the task status.

        Args:
            execution: The WorkflowNodeExecution instance to save or update
        """
        try:
            # Store in cache immediately for fast reads
            self._execution_cache[execution.id] = execution

            # Update workflow execution mapping for efficient retrieval
            if execution.workflow_execution_id:
                if execution.workflow_execution_id not in self._workflow_execution_mapping:
                    self._workflow_execution_mapping[execution.workflow_execution_id] = []
                if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
                    self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)

            # Serialize execution for Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task (fire and forget)
            save_workflow_node_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)

        except Exception:
            logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
            # In case of Celery failure, we could implement a fallback to synchronous save
            # For now, we'll re-raise the exception
            raise

    def get_by_workflow_run(
        self,
        workflow_run_id: str,
        order_config: Optional[OrderConfig] = None,
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.

        Args:
            workflow_run_id: The workflow run ID
            order_config: Optional configuration for ordering results

        Returns:
            A sequence of WorkflowNodeExecution instances
        """
        try:
            # Get execution IDs for this workflow run from cache
            execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])

            # Retrieve executions from cache
            result = []
            for execution_id in execution_ids:
                if execution_id in self._execution_cache:
                    result.append(self._execution_cache[execution_id])

            # Apply ordering if specified
            if order_config and result:
                # Sort based on the order configuration
                reverse = order_config.order_direction == "desc"

                # Sort by multiple fields if specified
                for field_name in reversed(order_config.order_by):
                    result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)

            logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
            return result

        except Exception:
            logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
            return []
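The ordering loop in `get_by_workflow_run` leans on Python's stable sort: sorting repeatedly from the least-significant field backwards leaves the earlier fields as the primary keys. A standalone illustration:

```python
# Stable multi-key sort, as used above: sort by the last field first, then by
# earlier fields; ties in later sorts preserve the earlier ordering.
rows = [(2, "b"), (1, "b"), (2, "a")]
for position in reversed([0, 1]):  # analogous to order_by=["first", "second"]
    rows.sort(key=lambda row, p=position: row[p])
print(rows)  # [(1, 'b'), (2, 'a'), (2, 'b')]
```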
@@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
    def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
        """
        Validate that a repository class constructor accepts required parameters.

        Args:
            repository_class: The class to validate
            required_params: List of required parameter names

        Raises:
            RepositoryImportError: If the constructor doesn't accept required parameters
        """
@@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
            cls._validate_constructor_signature(
                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
            )

            # All repository types now use the same constructor parameters
            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
@@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
            cls._validate_constructor_signature(
                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
            )

            # All repository types now use the same constructor parameters
            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
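The body of `_validate_constructor_signature` is not shown in this diff; based on its docstring and the test that patches `inspect.signature`, it plausibly inspects `__init__` along these lines. The helper name below is illustrative, not the factory's actual code:

```python
# Hedged reconstruction of the constructor check; only the required parameter
# names and the use of inspect.signature are evidenced by the diff and tests.
import inspect

def accepts_required_params(repository_class: type, required_params: list[str]) -> bool:
    params = inspect.signature(repository_class.__init__).parameters
    return all(name in params for name in required_params)

class DemoRepository:
    def __init__(self, session_factory, user, app_id, triggered_from):
        pass

print(accepts_required_params(DemoRepository, ["session_factory", "user", "app_id", "triggered_from"]))  # True
```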
@@ -1,4 +1,5 @@
 from collections.abc import Mapping
+from decimal import Decimal
 from typing import Any

 from pydantic import BaseModel
@@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
             return value
         if isinstance(value, (bool, int, str, float)):
             return value
+        if isinstance(value, Decimal):
+            # Convert Decimal to float for JSON serialization
+            return float(value)
         if isinstance(value, Segment):
             return self._to_json_encodable_recursive(value.value)
         if isinstance(value, File):
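The new `Decimal` branch exists because `json.dumps` cannot serialize `Decimal` values; converting to `float` trades exactness for serializability:

```python
# Why the Decimal branch is needed: json.dumps raises on Decimal inputs.
import json
from decimal import Decimal

value = Decimal("0.1")
try:
    json.dumps({"v": value})
except TypeError as err:
    print(err)  # Object of type Decimal is not JSON serializable
print(json.dumps({"v": float(value)}))  # {"v": 0.1}
```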
@@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
 
   exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
     --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
+    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
 
 elif [[ "${MODE}" == "beat" ]]; then
   exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
api/tasks/workflow_execution_tasks.py (new file, 136 lines)
@@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.

These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow execution to the database.

    Args:
        execution_data: Serialized WorkflowExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True if the save succeeded; on failure the task is retried with exponential backoff
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowExecution.model_validate(execution_data)

            # Check if workflow run already exists
            existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))

            if existing_run:
                # Update existing workflow run
                _update_workflow_run_from_execution(existing_run, execution)
                logger.debug("Updated existing workflow run: %s", execution.id_)
            else:
                # Create new workflow run
                workflow_run = _create_workflow_run_from_execution(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowRunTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(workflow_run)
                logger.debug("Created new workflow run: %s", execution.id_)

            session.commit()
            return True

    except Exception as e:
        logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


def _create_workflow_run_from_execution(
    execution: WorkflowExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowRunTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowRun:
    """
    Create a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    workflow_run = WorkflowRun()
    workflow_run.id = execution.id_
    workflow_run.tenant_id = tenant_id
    workflow_run.app_id = app_id
    workflow_run.workflow_id = execution.workflow_id
    workflow_run.type = execution.workflow_type.value
    workflow_run.triggered_from = triggered_from.value
    workflow_run.version = execution.workflow_version
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
    workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.created_by_role = creator_user_role.value
    workflow_run.created_by = creator_user_id
    workflow_run.created_at = execution.started_at
    workflow_run.finished_at = execution.finished_at

    return workflow_run


def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
    """
    Update a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.finished_at = execution.finished_at
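With `max_retries=3` and `countdown=60 * (2**self.request.retries)`, a failing save is retried after 60s, 120s, and 240s before the task gives up:

```python
# Retry schedule implied by the task's backoff formula.
for retries in range(3):
    print(f"retry {retries + 1} after {60 * (2 ** retries)}s")  # 60s, 120s, 240s
```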
api/tasks/workflow_node_execution_tasks.py (new file, 171 lines)
@@ -0,0 +1,171 @@
"""
|
||||
Celery tasks for asynchronous workflow node execution storage operations.
|
||||
|
||||
These tasks provide asynchronous storage capabilities for workflow node execution data,
|
||||
improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowNodeExecutionModel
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
|
||||
def save_workflow_node_execution_task(
|
||||
self,
|
||||
execution_data: dict,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save or update a workflow node execution to the database.
|
||||
|
||||
Args:
|
||||
execution_data: Serialized WorkflowNodeExecution data
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
creator_user_id: ID of the user who created the execution
|
||||
creator_user_role: Role of the user who created the execution
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowNodeExecution.model_validate(execution_data)
|
||||
|
||||
# Check if node execution already exists
|
||||
existing_execution = session.scalar(
|
||||
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
|
||||
)
|
||||
|
||||
if existing_execution:
|
||||
# Update existing node execution
|
||||
_update_node_execution_from_domain(existing_execution, execution)
|
||||
logger.debug("Updated existing workflow node execution: %s", execution.id)
|
||||
else:
|
||||
# Create new node execution
|
||||
node_execution = _create_node_execution_from_domain(
|
||||
execution=execution,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
|
||||
creator_user_id=creator_user_id,
|
||||
creator_user_role=CreatorUserRole(creator_user_role),
|
||||
)
|
||||
session.add(node_execution)
|
||||
logger.debug("Created new workflow node execution: %s", execution.id)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
|
||||
# Retry the task with exponential backoff
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
def _create_node_execution_from_domain(
|
||||
execution: WorkflowNodeExecution,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
||||
creator_user_id: str,
|
||||
creator_user_role: CreatorUserRole,
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
|
||||
"""
|
||||
node_execution = WorkflowNodeExecutionModel()
|
||||
node_execution.id = execution.id
|
||||
node_execution.tenant_id = tenant_id
|
||||
node_execution.app_id = app_id
|
||||
node_execution.workflow_id = execution.workflow_id
|
||||
node_execution.triggered_from = triggered_from.value
|
||||
node_execution.workflow_run_id = execution.workflow_execution_id
|
||||
node_execution.index = execution.index
|
||||
node_execution.predecessor_node_id = execution.predecessor_node_id
|
||||
node_execution.node_id = execution.node_id
|
||||
node_execution.node_type = execution.node_type.value
|
||||
node_execution.title = execution.title
|
||||
node_execution.node_execution_id = execution.node_execution_id
|
||||
|
||||
# Serialize complex data as JSON
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
}
|
||||
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
|
||||
else:
|
||||
node_execution.execution_metadata = "{}"
|
||||
|
||||
node_execution.status = execution.status.value
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
node_execution.created_by_role = creator_user_role.value
|
||||
node_execution.created_by = creator_user_id
|
||||
node_execution.created_at = execution.created_at
|
||||
node_execution.finished_at = execution.finished_at
|
||||
|
||||
return node_execution
|
||||
|
||||
|
||||
def _update_node_execution_from_domain(
|
||||
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
|
||||
) -> None:
|
||||
"""
|
||||
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
|
||||
"""
|
||||
# Update serialized data
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
}
|
||||
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
|
||||
else:
|
||||
node_execution.execution_metadata = "{}"
|
||||
|
||||
# Update other fields
|
||||
node_execution.status = execution.status.value
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
node_execution.finished_at = execution.finished_at
|
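Both helpers convert enum metadata keys to strings because JSON object keys must be strings. A self-contained illustration (`MetaKey` is an illustrative enum, not one from the codebase):

```python
# Enum keys are replaced by their .value before json.dumps, mirroring the
# comprehension used in both helpers above.
import json
from enum import Enum

class MetaKey(Enum):
    TOTAL_TOKENS = "total_tokens"

metadata = {MetaKey.TOTAL_TOKENS: 42}
as_json = {k.value if hasattr(k, "value") else str(k): v for k, v in metadata.items()}
print(json.dumps(as_json))  # {"total_tokens": 42}
```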
@@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest

from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Create a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_execution():
    """Sample WorkflowExecution for testing."""
    return WorkflowExecution.new(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowExecutionRepository:
    """Test cases for CeleryWorkflowExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_basic_functionality(self, mock_session_factory, mock_account):
        """Test repository initialization basic functionality."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )

        # Verify basic initialization
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == "test-app"
        assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """Test that save operation queues a Celery task without tracking."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        repo.save(sample_workflow_execution)

        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]

        assert call_args["execution_data"] == sample_workflow_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify no task tracking occurs (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_execution)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_operation_fire_and_forget(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation works in fire-and-forget mode."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Test that save doesn't block or maintain state
        repo.save(sample_workflow_execution)

        # Verify no pending saves are tracked (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
        """Test multiple save operations work correctly."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Create multiple executions
        exec1 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input2": "value2"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save both executions
        repo.save(exec1)
        repo.save(exec2)

        # Should work without issues and not maintain state (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
        """Test save operation with different user types."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        execution = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )

        repo.save(execution)

        # Verify task was called with EndUser context
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["tenant_id"] == mock_end_user.tenant_id
        assert call_args["creator_user_id"] == mock_end_user.id
@@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest

from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Create a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_node_execution():
    """Sample WorkflowNodeExecution for testing."""
    return WorkflowNodeExecution(
        id=str(uuid4()),
        node_execution_id=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        index=1,
        node_id="test_node",
        node_type=NodeType.START,
        title="Test Node",
        inputs={"input1": "value1"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowNodeExecutionRepository:
    """Test cases for CeleryWorkflowNodeExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
        """Test repository initialization with cache properly initialized."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        assert repo._execution_cache == {}
        assert repo._workflow_execution_mapping == {}

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowNodeExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_caches_and_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation caches execution and queues a Celery task."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        repo.save(sample_workflow_node_execution)

        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]

        assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify execution is cached
        assert sample_workflow_node_execution.id in repo._execution_cache
        assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution

        # Verify workflow execution mapping is updated
        assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
        assert (
            sample_workflow_node_execution.id
            in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
        )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_node_execution)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_get_by_workflow_run_from_cache(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that get_by_workflow_run retrieves executions from cache."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Save execution to cache first
        repo.save(sample_workflow_node_execution)

        workflow_run_id = sample_workflow_node_execution.workflow_execution_id
        order_config = OrderConfig(order_by=["index"], order_direction="asc")

        result = repo.get_by_workflow_run(workflow_run_id, order_config)

        # Verify results were retrieved from cache
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id
        assert result[0] is sample_workflow_node_execution

    def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
        """Test get_by_workflow_run without order configuration."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        result = repo.get_by_workflow_run("workflow-run-id")

        # Should return empty list since nothing in cache
        assert len(result) == 0

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Test cache operations work correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Test saving to cache
        repo.save(sample_workflow_node_execution)

        # Verify cache contains the execution
        assert sample_workflow_node_execution.id in repo._execution_cache

        # Test retrieving from cache
        result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
        """Test multiple executions for the same workflow."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Create multiple executions for the same workflow
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.START,
            title="Node 1",
            inputs={"input1": "value1"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.LLM,
            title="Node 2",
            inputs={"input2": "value2"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save both executions
        repo.save(exec1)
        repo.save(exec2)

        # Verify both are cached and mapped
        assert len(repo._execution_cache) == 2
        assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2

        # Test retrieval
        result = repo.get_by_workflow_run(workflow_run_id)
        assert len(result) == 2

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
        """Test ordering functionality works correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Create executions with different indices
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.START,
            title="Node 2",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.LLM,
            title="Node 1",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save out of index order
        repo.save(exec1)
        repo.save(exec2)

        # Test ascending order
        order_config = OrderConfig(order_by=["index"], order_direction="asc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)
        assert len(result) == 2
        assert result[0].index == 1
        assert result[1].index == 2

        # Test descending order
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)
        assert len(result) == 2
        assert result[0].index == 2
        assert result[1].index == 1
@@ -59,7 +59,7 @@ class TestRepositoryFactory:
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with the same methods
|
||||
# Create a mock interface class
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
@@ -67,20 +67,20 @@ class TestRepositoryFactory:
            def get_by_id(self):
                pass

        # Should not raise an exception
        # Should not raise an exception when all methods are present
        DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)

    def test_validate_repository_interface_missing_methods(self):
        """Test interface validation with missing methods."""

        # Create a mock class that doesn't implement all required methods
        # Create a mock class that's missing required methods
        class IncompleteRepository:
            def save(self):
                pass

            # Missing get_by_id method

        # Create a mock interface with required methods
        # Create a mock interface that requires both methods
        class MockInterface:
            def save(self):
                pass
@@ -88,57 +88,39 @@ class TestRepositoryFactory:
            def get_by_id(self):
                pass

            def missing_method(self):
                pass

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
        assert "does not implement required methods" in str(exc_info.value)
        assert "get_by_id" in str(exc_info.value)

    def test_validate_constructor_signature_success(self):
        """Test successful constructor signature validation."""
    def test_validate_repository_interface_with_private_methods(self):
        """Test that private methods are ignored during interface validation."""

        class MockRepository:
            def __init__(self, session_factory, user, app_id, triggered_from):
            def save(self):
                pass

        # Should not raise an exception
        DifyCoreRepositoryFactory._validate_constructor_signature(
            MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
        )

    def test_validate_constructor_signature_missing_params(self):
        """Test constructor validation with missing parameters."""

        class IncompleteRepository:
            def __init__(self, session_factory, user):
                # Missing app_id and triggered_from parameters
            def _private_method(self):
                pass

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory._validate_constructor_signature(
                IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
            )
        assert "does not accept required parameters" in str(exc_info.value)
        assert "app_id" in str(exc_info.value)
        assert "triggered_from" in str(exc_info.value)

    def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
        """Test constructor validation when inspection fails."""
        # Mock inspect.signature to raise an exception
        mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))

        class MockRepository:
            def __init__(self, session_factory):
        # Create a mock interface with private methods
        class MockInterface:
            def save(self):
                pass

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
        assert "Failed to validate constructor signature" in str(exc_info.value)
            def _private_method(self):
                pass

        # Should not raise exception - private methods should be ignored
        DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
        """Test successful creation of WorkflowExecutionRepository."""
    def test_create_workflow_execution_repository_success(self, mock_config):
        """Test successful WorkflowExecutionRepository creation."""
        # Setup mock configuration
        mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        # Create mock dependencies
        mock_session_factory = MagicMock(spec=sessionmaker)
@@ -146,7 +128,7 @@ class TestRepositoryFactory:
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        # Mock the imported class to be a valid repository
        # Create mock repository class and instance
        mock_repository_class = MagicMock()
        mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
        mock_repository_class.return_value = mock_repository_instance
@@ -155,7 +137,6 @@ class TestRepositoryFactory:
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
            patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
        ):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=mock_session_factory,
@@ -177,7 +158,7 @@ class TestRepositoryFactory:
    def test_create_workflow_execution_repository_import_error(self, mock_config):
        """Test WorkflowExecutionRepository creation with import error."""
        # Setup mock configuration with invalid class path
        mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=Account)
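The import-error cases here imply the factory resolves the configured dotted path at runtime and wraps failures in `RepositoryImportError` with the asserted message. A plausible sketch of that resolution step; the function name and structure are assumptions, only the error wording comes from the tests:

```python
import importlib


class RepositoryImportError(Exception):
    """Raised when a configured repository class cannot be loaded."""


def import_class(class_path: str) -> type:
    """Resolve a dotted 'package.module.ClassName' string to a class object."""
    module_path, _, class_name = class_path.rpartition(".")
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError) as e:
        raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e


# "unittest.mock.MagicMock" resolves fine; "invalid.module.InvalidClass" raises.
cls = import_class("unittest.mock.MagicMock")
```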
@@ -195,45 +176,46 @@ class TestRepositoryFactory:
    def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
        """Test WorkflowExecutionRepository creation with validation error."""
        # Setup mock configuration
        mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=Account)

        # Mock import to succeed but validation to fail
        # Mock the import to succeed but validation to fail
        mock_repository_class = MagicMock()
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(
                DifyCoreRepositoryFactory,
                "_validate_repository_interface",
                side_effect=RepositoryImportError("Interface validation failed"),
            ),
        ):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=mock_session_factory,
                    user=mock_user,
                    app_id="test-app-id",
                    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
                )
            assert "Interface validation failed" in str(exc_info.value)
        mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
        mocker.patch.object(
            DifyCoreRepositoryFactory,
            "_validate_repository_interface",
            side_effect=RepositoryImportError("Interface validation failed"),
        )

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=mock_session_factory,
                user=mock_user,
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )
        assert "Interface validation failed" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
    def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
        """Test WorkflowExecutionRepository creation with instantiation error."""
        # Setup mock configuration
        mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=Account)

        # Mock import and validation to succeed but instantiation to fail
        mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
        # Create a mock repository class that raises exception on instantiation
        mock_repository_class = MagicMock()
        mock_repository_class.side_effect = Exception("Instantiation failed")

        # Mock the validation methods to succeed
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
            patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
        ):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_execution_repository(
@@ -245,18 +227,18 @@ class TestRepositoryFactory:
            assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
        """Test successful creation of WorkflowNodeExecutionRepository."""
    def test_create_workflow_node_execution_repository_success(self, mock_config):
        """Test successful WorkflowNodeExecutionRepository creation."""
        # Setup mock configuration
        mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        # Create mock dependencies
        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP

        # Mock the imported class to be a valid repository
        # Create mock repository class and instance
        mock_repository_class = MagicMock()
        mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
        mock_repository_class.return_value = mock_repository_instance
@@ -265,7 +247,6 @@ class TestRepositoryFactory:
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
            patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
        ):
            result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=mock_session_factory,
@@ -287,7 +268,7 @@ class TestRepositoryFactory:
    def test_create_workflow_node_execution_repository_import_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with import error."""
        # Setup mock configuration with invalid class path
        mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)
@@ -297,28 +278,83 @@ class TestRepositoryFactory:
                session_factory=mock_session_factory,
                user=mock_user,
                app_id="test-app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            )
        assert "Cannot import repository class" in str(exc_info.value)

    def test_repository_import_error_exception(self):
        """Test RepositoryImportError exception."""
        error_message = "Test error message"
        exception = RepositoryImportError(error_message)
        assert str(exception) == error_message
        assert isinstance(exception, Exception)
    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
        """Test WorkflowNodeExecutionRepository creation with validation error."""
        # Setup mock configuration
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)

        # Mock the import to succeed but validation to fail
        mock_repository_class = MagicMock()
        mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
        mocker.patch.object(
            DifyCoreRepositoryFactory,
            "_validate_repository_interface",
            side_effect=RepositoryImportError("Interface validation failed"),
        )

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=mock_session_factory,
                user=mock_user,
                app_id="test-app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            )
        assert "Interface validation failed" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
    def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with instantiation error."""
        # Setup mock configuration
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)

        # Create a mock repository class that raises exception on instantiation
        mock_repository_class = MagicMock()
        mock_repository_class.side_effect = Exception("Instantiation failed")

        # Mock the validation methods to succeed
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
        ):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                    session_factory=mock_session_factory,
                    user=mock_user,
                    app_id="test-app-id",
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
                )
            assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    def test_repository_import_error_exception(self):
        """Test RepositoryImportError exception handling."""
        error_message = "Custom error message"
        error = RepositoryImportError(error_message)
        assert str(error) == error_message

    @patch("core.repositories.factory.dify_config")
    def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
        """Test repository creation with Engine instead of sessionmaker."""
        # Setup mock configuration
        mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        # Create mock dependencies with Engine instead of sessionmaker
        # Create mock dependencies using Engine instead of sessionmaker
        mock_engine = MagicMock(spec=Engine)
        mock_user = MagicMock(spec=Account)
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        # Mock the imported class to be a valid repository
        # Create mock repository class and instance
        mock_repository_class = MagicMock()
        mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
        mock_repository_class.return_value = mock_repository_instance
@@ -327,129 +363,19 @@ class TestRepositoryFactory:
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
            patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
        ):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=mock_engine,  # Using Engine instead of sessionmaker
                user=mock_user,
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
                app_id=app_id,
                triggered_from=triggered_from,
            )

        # Verify the repository was created with the Engine
        # Verify the repository was created with correct parameters
        mock_repository_class.assert_called_once_with(
            session_factory=mock_engine,
            user=mock_user,
            app_id="test-app-id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        assert result is mock_repository_instance

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with validation error."""
        # Setup mock configuration
        mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)

        # Mock import to succeed but validation to fail
        mock_repository_class = MagicMock()
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(
                DifyCoreRepositoryFactory,
                "_validate_repository_interface",
                side_effect=RepositoryImportError("Interface validation failed"),
            ),
        ):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                    session_factory=mock_session_factory,
                    user=mock_user,
                    app_id="test-app-id",
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
                )
            assert "Interface validation failed" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with instantiation error."""
        # Setup mock configuration
        mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)

        # Mock import and validation to succeed but instantiation to fail
        mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
        with (
            patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
            patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
            patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
        ):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                    session_factory=mock_session_factory,
                    user=mock_user,
                    app_id="test-app-id",
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
                )
            assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    def test_validate_repository_interface_with_private_methods(self):
        """Test interface validation ignores private methods."""

        # Create a mock class with private methods
        class MockRepository:
            def save(self):
                pass

            def get_by_id(self):
                pass

            def _private_method(self):
                pass

        # Create a mock interface with private methods
        class MockInterface:
            def save(self):
                pass

            def get_by_id(self):
                pass

            def _private_method(self):
                pass

        # Should not raise an exception (private methods are ignored)
        DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)

    def test_validate_constructor_signature_with_extra_params(self):
        """Test constructor validation with extra parameters (should pass)."""

        class MockRepository:
            def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
                pass

        # Should not raise an exception (extra parameters are allowed)
        DifyCoreRepositoryFactory._validate_constructor_signature(
            MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
        )

    def test_validate_constructor_signature_with_kwargs(self):
        """Test constructor validation with **kwargs (current implementation doesn't support this)."""

        class MockRepository:
            def __init__(self, session_factory, user, **kwargs):
                pass

        # Current implementation doesn't handle **kwargs, so this should raise an exception
        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory._validate_constructor_signature(
                MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
            )
        assert "does not accept required parameters" in str(exc_info.value)
        assert "app_id" in str(exc_info.value)
        assert "triggered_from" in str(exc_info.value)
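Read together, these validation tests specify both factory checks end to end: every public interface method must exist on the candidate class (private names ignored), and the constructor must name each required parameter explicitly, with extra defaulted parameters allowed but `**kwargs` not counted. A sketch consistent with those assertions, assuming only the error wording from the tests, not the shipped implementation:

```python
import inspect


class RepositoryImportError(Exception):
    pass


def validate_repository_interface(repository_class: type, interface: type) -> None:
    """Every public method on the interface must exist on the candidate class."""
    required = {name for name, _ in inspect.getmembers(interface, inspect.isfunction) if not name.startswith("_")}
    provided = {name for name, _ in inspect.getmembers(repository_class, inspect.isfunction)}
    missing = required - provided
    if missing:
        raise RepositoryImportError(
            f"{repository_class.__name__} does not implement required methods: {sorted(missing)}"
        )


def validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
    """The constructor must name each required parameter; **kwargs does not satisfy this."""
    try:
        params = inspect.signature(repository_class.__init__).parameters
    except Exception as e:
        raise RepositoryImportError(f"Failed to validate constructor signature: {e}") from e
    missing = [p for p in required_params if p not in params]
    if missing:
        raise RepositoryImportError(
            f"{repository_class.__name__} does not accept required parameters: {missing}"
        )
```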
@@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."

uv --directory api run \
  celery -A app.celery worker \
  -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
  -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
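The worker's queue list gains `plugin` and `workflow_storage`; the latter is presumably where the Celery repositories' save tasks are routed. If so, a worker dedicated to just the async persistence traffic could be started the same way, e.g.:

```bash
uv --directory api run \
  celery -A app.celery worker \
  -P gevent -c 1 --loglevel INFO -Q workflow_storage
```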
@@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms

# Repository configuration
# Core workflow execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository

# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository

# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository

# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository

# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository

# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
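Switching a deployment to the asynchronous implementations should, per the Options comments above, only require pointing both variables at the Celery classes and running a worker that consumes the `workflow_storage` queue; a hedged example:

```bash
# Opt into the Celery-backed (asynchronous) repositories
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
```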
@@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env
  WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
  CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
  CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
  API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
  API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
  API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
  HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
  HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
  HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}