feat(workflow_cycle_manager): Removes redundant repository methods and adds caching (#22597)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN-
2025-07-18 09:26:05 +08:00
committed by GitHub
parent 3826b57424
commit b88dd17fc1
7 changed files with 294 additions and 535 deletions

View File

@@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
# First check the cache
if execution_id in self._execution_cache:
logger.debug(f"Cache hit for execution_id: {execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._execution_cache[execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == execution_id,
WorkflowRun.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowRun.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._execution_cache[execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
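For reference, the method removed here is a textbook cache-aside read: check an in-memory map first, fall back to the database, and populate the map on a hit. A minimal self-contained sketch of the same pattern (CacheAsideRepo and fetch_from_db are illustrative names, not part of the codebase):

from typing import Callable, Optional


class CacheAsideRepo:
    """Illustrative cache-aside lookup: check memory first, then the DB."""

    def __init__(self, fetch_from_db: Callable[[str], Optional[object]]):
        self._fetch_from_db = fetch_from_db
        self._cache: dict[str, object] = {}

    def get(self, key: str) -> Optional[object]:
        if key in self._cache:
            return self._cache[key]      # cache hit: no DB round trip
        row = self._fetch_from_db(key)   # cache miss: query storage
        if row is not None:
            self._cache[key] = row       # populate for future lookups
        return row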

View File

@@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
# First check the cache
if node_execution_id in self._node_execution_cache:
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._node_execution_cache[node_execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._node_execution_cache[node_execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
domain_models.append(domain_model)
return domain_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
This method queries the database directly and updates the cache with any
retrieved executions that have a node_execution_id.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_models = session.scalars(stmt).all()
domain_models = []
for model in db_models:
# Update cache if node_execution_id is present
if model.node_execution_id:
self._node_execution_cache[model.node_execution_id] = model
# Convert to domain model
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
return domain_models
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
result = session.execute(stmt)
session.commit()
deleted_count = result.rowcount
logger.info(
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)
# Clear the in-memory cache
self._node_execution_cache.clear()
logger.info("Cleared in-memory node execution cache")

View File

@@ -1,4 +1,4 @@
from typing import Optional, Protocol
from typing import Protocol
from core.workflow.entities.workflow_execution import WorkflowExecution
@@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
execution: The WorkflowExecution instance to save or update
"""
...
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
...
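With get() removed, a test double for this protocol only needs save(). A minimal fake, sketched here for illustration (InMemoryExecutionRepo is not part of the codebase):

class InMemoryExecutionRepo:
    """Minimal fake satisfying the slimmed protocol: only save() remains."""

    def __init__(self) -> None:
        self.saved: list[object] = []

    def save(self, execution: object) -> None:
        self.saved.append(execution)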

View File

@@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
@@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
A list of NodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
...
def clear(self) -> None:
"""
Clear all NodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...
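These interfaces are typing.Protocol classes, so deleting methods genuinely shrinks the contract: conformance is structural, and implementations need no inheritance. A small self-contained sketch of that mechanism (Saver and ListSaver are illustrative names):

from typing import Protocol


class Saver(Protocol):
    def save(self, item: object) -> None: ...


class ListSaver:  # conforms to Saver without subclassing it
    def __init__(self) -> None:
        self.items: list[object] = []

    def save(self, item: object) -> None:
        self.items.append(item)


def persist(repo: Saver, item: object) -> None:
    repo.save(item)


persist(ListSaver(), "payload")  # accepted by type checkers and at runtime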

View File

@@ -55,24 +55,15 @@ class WorkflowCycleManager:
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
# Initialize caches for workflow execution cycle
# These caches avoid redundant repository calls during a single workflow execution
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
inputs = self._prepare_workflow_inputs()
execution_id = self._get_or_generate_execution_id()
# Merge the SystemVariable fields (via to_dict()) into the inputs
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID:
continue
inputs[f"sys.{field_name}"] = value
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should never be None; consider a more elegant way to handle this
execution_id = str(
self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
) or str(uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
@@ -83,9 +74,7 @@ class WorkflowCycleManager:
started_at=datetime.now(UTC).replace(tzinfo=None),
)
self._workflow_execution_repository.save(execution)
return execution
return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
@@ -99,24 +88,16 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(outputs)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
self._update_workflow_execution_completion(
workflow_execution,
status=WorkflowExecutionStatus.SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -132,24 +113,17 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
self._update_workflow_execution_completion(
execution,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
exceptions_count=exceptions_count,
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._add_trace_task_if_needed(trace_manager, execution, conversation_id)
self._workflow_execution_repository.save(execution)
return execution
@@ -169,39 +143,18 @@ class WorkflowCycleManager:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
workflow_execution.status = WorkflowExecutionStatus(status.value)
workflow_execution.error_message = error_message
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = now
workflow_execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_execution.id_
self._update_workflow_execution_completion(
workflow_execution,
status=status,
total_tokens=total_tokens,
total_steps=total_steps,
error_message=error_message,
exceptions_count=exceptions_count,
finished_at=now,
)
# Update the domain models
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
# Update the repository with the domain model
self._workflow_node_execution_repository.save(node_execution)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -214,8 +167,198 @@ class WorkflowCycleManager:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
)
return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
self._update_node_execution_completion(
domain_execution,
event=event,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Handle a failed workflow node execution.
:param event: queue node failed event
:return: the updated WorkflowNodeExecution
"""
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
status = (
WorkflowNodeExecutionStatus.EXCEPTION
if isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.FAILED
)
self._update_node_execution_completion(
domain_execution,
event=event,
status=status,
error=event.error,
handle_special_values=True,
)
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
created_at=event.start_at,
)
# Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
metadata = self._merge_event_metadata(event)
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
return self._save_and_cache_node_execution(domain_execution)
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)
def _prepare_workflow_inputs(self) -> dict[str, Any]:
"""Prepare workflow inputs by merging application inputs with system variables."""
inputs = {**self._application_generate_entity.inputs}
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name != SystemVariableKey.CONVERSATION_ID:
inputs[f"sys.{field_name}"] = value
return dict(WorkflowEntry.handle_special_values(inputs) or {})
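The merging rule _prepare_workflow_inputs implements: application inputs come first, system variables are namespaced under a sys. prefix, and conversation_id is excluded. A self-contained illustration with made-up values:

app_inputs = {"query": "hello"}
system_vars = {"user_id": "u1", "conversation_id": "c1", "workflow_id": "w1"}

inputs = {**app_inputs}
for field_name, value in system_vars.items():
    if field_name != "conversation_id":  # conversation id is not exposed as sys.*
        inputs[f"sys.{field_name}"] = value

assert inputs == {"query": "hello", "sys.user_id": "u1", "sys.workflow_id": "w1"}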
def _get_or_generate_execution_id(self) -> str:
"""Get execution ID from system variables or generate a new one."""
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
return str(self._workflow_system_variables.workflow_execution_id)
return str(uuid4())
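The extracted helper also sidesteps a quirk of the old inline expression it replaces: str(None) is the truthy string "None", so the or str(uuid4()) fallback could never fire, and the literal "None" could leak as an execution ID. A quick demonstration:

from uuid import uuid4

# Old inline pattern: str(None) == "None" is truthy, so the uuid fallback is dead code.
execution_id = str(None) or str(uuid4())
assert execution_id == "None"  # not a fresh UUID

# New helper pattern: test the value before stringifying, then fall back explicitly.
workflow_execution_id = None
execution_id = str(workflow_execution_id) if workflow_execution_id else str(uuid4())
assert execution_id != "None"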
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
"""Save workflow execution to repository and cache it."""
self._workflow_execution_repository.save(execution)
self._workflow_execution_cache[execution.id_] = execution
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID."""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution
return execution
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
"""Get node execution from cache or raise error if not found."""
domain_execution = self._node_execution_cache.get(node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {node_execution_id}")
return domain_execution
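Taken together, the _save_and_cache_* helpers and _get_node_execution_from_cache implement a write-through cache: every save also updates memory, so later reads within the same workflow cycle never touch the repository. A simplified stand-alone sketch (WriteThroughCache and repo_save are illustrative, not from the codebase):

from typing import Callable


class WriteThroughCache:
    def __init__(self, repo_save: Callable[[object], None]):
        self._repo_save = repo_save
        self._cache: dict[str, object] = {}

    def save_and_cache(self, key: str, record: object) -> object:
        self._repo_save(record)    # persist first
        self._cache[key] = record  # then keep the domain model in memory
        return record

    def get_or_raise(self, key: str) -> object:
        record = self._cache.get(key)
        if record is None:
            raise ValueError(f"not found: {key}")
        return record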
def _update_workflow_execution_completion(
self,
execution: WorkflowExecution,
*,
status: WorkflowExecutionStatus,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: Optional[str] = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
) -> None:
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = finished_at or naive_utc_now()
execution.exceptions_count = exceptions_count
if error_message:
execution.error_message = error_message
def _add_trace_task_if_needed(
self,
trace_manager: Optional[TraceQueueManager],
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
) -> None:
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
def _fail_running_node_executions(
self,
workflow_execution_id: str,
error_message: str,
now: datetime,
) -> None:
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution_id
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
for node_execution in running_node_executions:
if node_execution.node_execution_id:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
self._workflow_node_execution_repository.save(node_execution)
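Because every node execution created during the cycle passes through the cache, failing the still-running ones no longer needs the removed get_running_executions query; a scan over cached entries suffices. A runnable miniature of that path, with NodeExec as a simplified stand-in for the domain model:

from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Optional


@dataclass
class NodeExec:
    node_execution_id: str
    workflow_execution_id: str
    status: str = "running"
    error: Optional[str] = None
    finished_at: Optional[datetime] = None
    elapsed_time: float = 0.0
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))


cache = {
    "n1": NodeExec("n1", "run-1"),
    "n2": NodeExec("n2", "run-2"),  # different run: must stay untouched
}
now = datetime.now(UTC) + timedelta(seconds=3)
for ne in cache.values():
    if ne.workflow_execution_id == "run-1" and ne.status == "running":
        ne.status = "failed"
        ne.error = "workflow stopped"
        ne.finished_at = now
        ne.elapsed_time = (now - ne.created_at).total_seconds()

assert cache["n1"].status == "failed" and cache["n2"].status == "running"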
def _create_node_execution_from_event(
self,
*,
workflow_execution: WorkflowExecution,
event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
created_at: Optional[datetime] = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = datetime.now(UTC).replace(tzinfo=None)
created_at = created_at or now
metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
@@ -232,152 +375,76 @@ class WorkflowCycleManager:
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=WorkflowNodeExecutionStatus.RUNNING,
status=status,
metadata=metadata,
created_at=created_at,
error=error,
)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
if status == WorkflowNodeExecutionStatus.RETRY:
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - created_at).total_seconds()
return domain_execution
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
def _update_node_execution_completion(
self,
domain_execution: WorkflowNodeExecution,
*,
event: Union[
QueueNodeSucceededEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
handle_special_values: bool = False,
) -> None:
"""Update node execution with completion data."""
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
# Update domain model
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Handle a failed workflow node execution.
:param event: queue node failed event
:return: the updated WorkflowNodeExecution
"""
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
if handle_special_values:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
else:
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
# Convert metadata
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata_dict.update(event.execution_metadata)
# Update domain model
domain_execution.status = (
WorkflowNodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION
)
domain_execution.error = event.error
domain_execution.status = status
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
inputs=inputs,
process_data=process_data,
outputs=outputs,
metadata=execution_metadata_dict,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
if error:
domain_execution.error = error
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
# Convert metadata keys to strings
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
"""Merge event metadata with origin metadata."""
origin_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
# Convert execution metadata keys to strings
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
execution_metadata_dict.update(event.execution_metadata)
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=WorkflowNodeExecutionStatus.RETRY,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,
error=event.error,
index=event.node_run_index,
)
# Update with mappings
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
execution = self._workflow_execution_repository.get(id)
if not execution:
raise WorkflowRunNotFoundError(id)
return execution
return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
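One subtlety in _merge_event_metadata: in {**execution_metadata_dict, **origin_metadata}, keys from origin_metadata win on collision, so the iteration, loop, and parallel-run identifiers derived from the event always take precedence over values carried in execution_metadata. Illustrated with plain dicts:

event_metadata = {"loop_id": "from-event", "total_tokens": 10}
origin_metadata = {"loop_id": "from-origin"}

merged = {**event_metadata, **origin_metadata}
assert merged == {"loop_id": "from-origin", "total_tokens": 10}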

View File

@@ -80,15 +80,12 @@ def real_workflow_system_variables():
@pytest.fixture
def mock_node_execution_repository():
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
repo.get_by_node_execution_id.return_value = None
repo.get_running_executions.return_value = []
return repo
@pytest.fixture
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
repo.get.return_value = None
return repo
@@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
@@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
# No running node executions in cache (empty cache)
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
@@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
@@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository get method to return the real execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Verify the result
assert result == workflow_execution
# Test error case
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
# Test error case - clear cache
workflow_cycle_manager._workflow_execution_cache.clear()
# Expect an error when execution is not found
with pytest.raises(ValueError):
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
with pytest.raises(WorkflowRunNotFoundError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
@@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_success(
@@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
@@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_failed(

View File

@@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
session_obj.merge.assert_called_once_with(modified_execution)
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
"""Test get_by_node_execution_id method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalar.return_value = mock_execution
# Create a mock domain model to be returned by _to_domain_model
mock_domain_model = mocker.MagicMock()
# Mock the _to_domain_model method to return our mock domain model
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
# Call method
result = repository.get_by_node_execution_id("test-node-execution-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalar.assert_called_once_with(mock_stmt)
# Assert _to_domain_model was called with the mock execution
repository._to_domain_model.assert_called_once_with(mock_execution)
# Assert the result is our mock domain model
assert result is mock_domain_model
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
@@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
assert result[0] is mock_domain_model
def test_get_running_executions(repository, session, mocker: MockerFixture):
"""Test get_running_executions method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalars.return_value.all.return_value = [mock_execution]
# Create a mock domain model to be returned by _to_domain_model
mock_domain_model = mocker.MagicMock()
# Mock the _to_domain_model method to return our mock domain model
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
# Call method
result = repository.get_running_executions("test-workflow-run-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
# Assert _to_domain_model was called with the mock execution
repository._to_domain_model.assert_called_once_with(mock_execution)
# Assert the result contains our mock domain model
assert len(result) == 1
assert result[0] is mock_domain_model
def test_update_via_save(repository, session):
"""Test updating an existing record via save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.tenant_id = None
execution.app_id = None
execution.inputs = None
execution.process_data = None
execution.outputs = None
execution.metadata = None
# Mock the to_db_model method to return the execution itself
# This simulates the behavior of setting tenant_id and app_id
repository.to_db_model = MagicMock(return_value=execution)
# Call save method to update an existing record
repository.save(execution)
# Assert to_db_model was called with the execution
repository.to_db_model.assert_called_once_with(execution)
# Assert session.merge was called (for updates)
session_obj.merge.assert_called_once_with(execution)
def test_clear(repository, session, mocker: MockerFixture):
"""Test clear method."""
session_obj, _ = session
# Set up mock
mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
mock_stmt = mocker.MagicMock()
mock_delete.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Mock the execute result with rowcount
mock_result = mocker.MagicMock()
mock_result.rowcount = 5 # Simulate 5 records deleted
session_obj.execute.return_value = mock_result
# Call method
repository.clear()
# Assert delete was called with correct parameters
mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
mock_stmt.where.assert_called()
session_obj.execute.assert_called_once_with(mock_stmt)
session_obj.commit.assert_called_once()
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model