From b88dd17fc1003e9e6d135e2c4885a79fb6f75853 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 18 Jul 2025 09:26:05 +0800 Subject: [PATCH] feat(workflow_cycle_manager): Removes redundant repository methods and adds caching (#22597) Signed-off-by: -LAN- --- ...qlalchemy_workflow_execution_repository.py | 42 -- ...hemy_workflow_node_execution_repository.py | 108 +--- .../workflow_execution_repository.py | 14 +- .../workflow_node_execution_repository.py | 33 -- api/core/workflow/workflow_cycle_manager.py | 477 ++++++++++-------- .../workflow/test_workflow_cycle_manager.py | 42 +- .../test_sqlalchemy_repository.py | 113 ----- 7 files changed, 294 insertions(+), 535 deletions(-) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 0b3e5eb42..c579ff402 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,6 @@ import json import logging from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Update the in-memory cache for faster subsequent lookups logger.debug(f"Updating cache for execution_id: {db_model.id}") self._execution_cache[db_model.id] = db_model - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - # First check the cache - if execution_id in self._execution_cache: - logger.debug(f"Cache hit for execution_id: {execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._execution_cache[execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.id == execution_id, - WorkflowRun.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowRun.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._execution_cache[execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index a5feeb0d7..d4a31390f 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import logging from collections.abc import Sequence from typing import Optional, Union -from sqlalchemy import UnaryExpression, asc, delete, desc, select +from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") 
self._node_execution_cache[db_model.node_execution_id] = db_model - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - # First check the cache - if node_execution_id in self._node_execution_cache: - logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._node_execution_cache[node_execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.node_execution_id == node_execution_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._node_execution_cache[node_execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None - def get_db_models_by_workflow_run( self, workflow_run_id: str, @@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) domain_models.append(domain_model) return domain_models - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - This method queries the database directly and updates the cache with any - retrieved executions that have a node_execution_id. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, - WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_models = session.scalars(stmt).all() - domain_models = [] - - for model in db_models: - # Update cache if node_execution_id is present - if model.node_execution_id: - self._node_execution_cache[model.node_execution_id] = model - - # Convert to domain model - domain_model = self._to_domain_model(model) - domain_models.append(domain_model) - - return domain_models - - def clear(self) -> None: - """ - Clear all WorkflowNodeExecution records for the current tenant_id and app_id. - - This method deletes all WorkflowNodeExecution records that match the tenant_id - and app_id (if provided) associated with this repository instance. - It also clears the in-memory cache. 
- """ - with self._session_factory() as session: - stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - result = session.execute(stmt) - session.commit() - - deleted_count = result.rowcount - logger.info( - f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" - + (f" and app {self._app_id}" if self._app_id else "") - ) - - # Clear the in-memory cache - self._node_execution_cache.clear() - logger.info("Cleared in-memory node execution cache") diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py index 5917310c8..bcbd25339 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -1,4 +1,4 @@ -from typing import Optional, Protocol +from typing import Protocol from core.workflow.entities.workflow_execution import WorkflowExecution @@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol): execution: The WorkflowExecution instance to save or update """ ... - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - ... diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index 1908a6b19..8bf81f544 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol): """ ... - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - ... - def get_by_workflow_run( self, workflow_run_id: str, @@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol): A list of NodeExecution instances """ ... - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - ... - - def clear(self) -> None: - """ - Clear all NodeExecution records based on implementation-specific criteria. - - This method is intended to be used for bulk deletion operations, such as removing - all records associated with a specific app_id and tenant_id in multi-tenant implementations. - """ - ... 
diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 50ff73397..3e591ef88 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -55,24 +55,15 @@ class WorkflowCycleManager: self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + # Initialize caches for workflow execution cycle + # These caches avoid redundant repository calls during a single workflow execution + self._workflow_execution_cache: dict[str, WorkflowExecution] = {} + self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} + def handle_workflow_run_start(self) -> WorkflowExecution: - inputs = {**self._application_generate_entity.inputs} + inputs = self._prepare_workflow_inputs() + execution_id = self._get_or_generate_execution_id() - # Iterate over SystemVariable fields using Pydantic's model_fields - if self._workflow_system_variables: - for field_name, value in self._workflow_system_variables.to_dict().items(): - if field_name == SystemVariableKey.CONVERSATION_ID: - continue - inputs[f"sys.{field_name}"] = value - - # handle special values - inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) - - # init workflow run - # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str( - self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None - ) or str(uuid4()) execution = WorkflowExecution.new( id_=execution_id, workflow_id=self._workflow_info.workflow_id, @@ -83,9 +74,7 @@ class WorkflowCycleManager: started_at=datetime.now(UTC).replace(tzinfo=None), ) - self._workflow_execution_repository.save(execution) - - return execution + return self._save_and_cache_workflow_execution(execution) def handle_workflow_run_success( self, @@ -99,23 +88,15 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(outputs) + self._update_workflow_execution_completion( + workflow_execution, + status=WorkflowExecutionStatus.SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + ) - workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED - workflow_execution.outputs = outputs or {} - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -132,24 +113,17 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) - execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED - execution.outputs = outputs or {} - execution.total_tokens = total_tokens - execution.total_steps = total_steps - execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - 
execution.exceptions_count = exceptions_count + self._update_workflow_execution_completion( + execution, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + exceptions_count=exceptions_count, + ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, execution, conversation_id) self._workflow_execution_repository.save(execution) return execution @@ -169,39 +143,18 @@ class WorkflowCycleManager: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() - workflow_execution.status = WorkflowExecutionStatus(status.value) - workflow_execution.error_message = error_message - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - workflow_execution.finished_at = now - workflow_execution.exceptions_count = exceptions_count - - # Use the instance repository to find running executions for a workflow run - running_node_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_execution.id_ + self._update_workflow_execution_completion( + workflow_execution, + status=status, + total_tokens=total_tokens, + total_steps=total_steps, + error_message=error_message, + exceptions_count=exceptions_count, + finished_at=now, ) - # Update the domain models - for node_execution in running_node_executions: - if node_execution.node_execution_id: - # Update the domain model - node_execution.status = WorkflowNodeExecutionStatus.FAILED - node_execution.error = error_message - node_execution.finished_at = now - node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() - - # Update the repository with the domain model - self._workflow_node_execution_repository.save(node_execution) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._fail_running_node_executions(workflow_execution.id_, error_message, now) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -214,8 +167,198 @@ class WorkflowCycleManager: ) -> WorkflowNodeExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - # Create a domain model - created_at = datetime.now(UTC).replace(tzinfo=None) + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + + return self._save_and_cache_node_execution(domain_execution) + + def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) + + self._update_node_execution_completion( + domain_execution, + event=event, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + self._workflow_node_execution_repository.save(domain_execution) + return domain_execution + + def handle_workflow_node_execution_failed( + self, + *, + event: QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeInLoopFailedEvent + | 
QueueNodeExceptionEvent, + ) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param event: queue node failed event + :return: + """ + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) + + status = ( + WorkflowNodeExecutionStatus.EXCEPTION + if isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.FAILED + ) + + self._update_node_execution_completion( + domain_execution, + event=event, + status=status, + error=event.error, + handle_special_values=True, + ) + + self._workflow_node_execution_repository.save(domain_execution) + return domain_execution + + def handle_workflow_node_execution_retried( + self, *, workflow_execution_id: str, event: QueueNodeRetryEvent + ) -> WorkflowNodeExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) + + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, + status=WorkflowNodeExecutionStatus.RETRY, + error=event.error, + created_at=event.start_at, + ) + + # Handle inputs and outputs + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = event.outputs + metadata = self._merge_event_metadata(event) + + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata) + + return self._save_and_cache_node_execution(domain_execution) + + def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: + # Check cache first + if id in self._workflow_execution_cache: + return self._workflow_execution_cache[id] + + raise WorkflowRunNotFoundError(id) + + def _prepare_workflow_inputs(self) -> dict[str, Any]: + """Prepare workflow inputs by merging application inputs with system variables.""" + inputs = {**self._application_generate_entity.inputs} + + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name != SystemVariableKey.CONVERSATION_ID: + inputs[f"sys.{field_name}"] = value + + return dict(WorkflowEntry.handle_special_values(inputs) or {}) + + def _get_or_generate_execution_id(self) -> str: + """Get execution ID from system variables or generate a new one.""" + if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id: + return str(self._workflow_system_variables.workflow_execution_id) + return str(uuid4()) + + def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution: + """Save workflow execution to repository and cache it.""" + self._workflow_execution_repository.save(execution) + self._workflow_execution_cache[execution.id_] = execution + return execution + + def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution: + """Save node execution to repository and cache it if it has an ID.""" + self._workflow_node_execution_repository.save(execution) + if execution.node_execution_id: + self._node_execution_cache[execution.node_execution_id] = execution + return execution + + def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution: + """Get node execution from cache or raise error if not found.""" + domain_execution = self._node_execution_cache.get(node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {node_execution_id}") + return domain_execution + + def _update_workflow_execution_completion( + self, + execution: WorkflowExecution, + *, + status: 
WorkflowExecutionStatus, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + error_message: Optional[str] = None, + exceptions_count: int = 0, + finished_at: Optional[datetime] = None, + ) -> None: + """Update workflow execution with completion data.""" + execution.status = status + execution.outputs = outputs or {} + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = finished_at or naive_utc_now() + execution.exceptions_count = exceptions_count + if error_message: + execution.error_message = error_message + + def _add_trace_task_if_needed( + self, + trace_manager: Optional[TraceQueueManager], + workflow_execution: WorkflowExecution, + conversation_id: Optional[str], + ) -> None: + """Add trace task if trace manager is provided.""" + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_execution=workflow_execution, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + def _fail_running_node_executions( + self, + workflow_execution_id: str, + error_message: str, + now: datetime, + ) -> None: + """Fail all running node executions for a workflow.""" + running_node_executions = [ + node_exec + for node_exec in self._node_execution_cache.values() + if node_exec.workflow_execution_id == workflow_execution_id + and node_exec.status == WorkflowNodeExecutionStatus.RUNNING + ] + + for node_execution in running_node_executions: + if node_execution.node_execution_id: + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = error_message + node_execution.finished_at = now + node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() + self._workflow_node_execution_repository.save(node_execution) + + def _create_node_execution_from_event( + self, + *, + workflow_execution: WorkflowExecution, + event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + created_at: Optional[datetime] = None, + ) -> WorkflowNodeExecution: + """Create a node execution from an event.""" + now = datetime.now(UTC).replace(tzinfo=None) + created_at = created_at or now + metadata = { WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, @@ -232,152 +375,76 @@ class WorkflowCycleManager: node_id=event.node_id, node_type=event.node_type, title=event.node_data.title, - status=WorkflowNodeExecutionStatus.RUNNING, + status=status, metadata=metadata, created_at=created_at, + error=error, ) - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) + if status == WorkflowNodeExecutionStatus.RETRY: + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - created_at).total_seconds() return domain_execution - def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - - # Process data - inputs = event.inputs - process_data = event.process_data - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - 
for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - event.start_at).total_seconds() - - # Update domain model - domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED - domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict - ) - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = elapsed_time - - # Update the repository with the domain model - self._workflow_node_execution_repository.save(domain_execution) - - return domain_execution - - def handle_workflow_node_execution_failed( + def _update_node_execution_completion( self, + domain_execution: WorkflowNodeExecution, *, - event: QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param event: queue node failed event - :return: - """ - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - - # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - + event: Union[ + QueueNodeSucceededEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + handle_special_values: bool = False, + ) -> None: + """Update node execution with completion data.""" finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() + # Process data + if handle_special_values: + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + else: + inputs = event.inputs + process_data = event.process_data + + outputs = event.outputs + + # Convert metadata + execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {} + if event.execution_metadata: + execution_metadata_dict.update(event.execution_metadata) + # Update domain model - domain_execution.status = ( - WorkflowNodeExecutionStatus.FAILED - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION - ) - domain_execution.error = event.error + domain_execution.status = status domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + inputs=inputs, + process_data=process_data, + outputs=outputs, + metadata=execution_metadata_dict, ) domain_execution.finished_at = finished_at domain_execution.elapsed_time = elapsed_time - # Update the repository with the domain model - self._workflow_node_execution_repository.save(domain_execution) + if error: + domain_execution.error = error - return domain_execution - - def handle_workflow_node_execution_retried( - self, *, workflow_execution_id: str, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: - 
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - created_at = event.start_at - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - created_at).total_seconds() - inputs = WorkflowEntry.handle_special_values(event.inputs) - outputs = event.outputs - - # Convert metadata keys to strings + def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]: + """Merge event metadata with origin metadata.""" origin_metadata = { WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, } - # Convert execution metadata keys to strings execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value + execution_metadata_dict.update(event.execution_metadata) - merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata - - # Create a domain model - domain_execution = WorkflowNodeExecution( - id=str(uuid4()), - workflow_id=workflow_execution.workflow_id, - workflow_execution_id=workflow_execution.id_, - predecessor_node_id=event.predecessor_node_id, - node_execution_id=event.node_execution_id, - node_id=event.node_id, - node_type=event.node_type, - title=event.node_data.title, - status=WorkflowNodeExecutionStatus.RETRY, - created_at=created_at, - finished_at=finished_at, - elapsed_time=elapsed_time, - error=event.error, - index=event.node_run_index, - ) - - # Update with mappings - domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) - - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) - - return domain_execution - - def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: - execution = self._workflow_execution_repository.get(id) - if not execution: - raise WorkflowRunNotFoundError(id) - return execution + return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 642bc810b..4866db1fd 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -80,15 +80,12 @@ def real_workflow_system_variables(): @pytest.fixture def mock_node_execution_repository(): repo = MagicMock(spec=WorkflowNodeExecutionRepository) - repo.get_by_node_execution_id.return_value = None - repo.get_running_executions.return_value = [] return repo @pytest.fixture def mock_workflow_execution_repository(): repo = MagicMock(spec=WorkflowExecutionRepository) - repo.get.return_value = None return repo @@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu started_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Pre-populate the cache with the workflow execution + workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = 
workflow_execution # Call the method result = workflow_cycle_manager.handle_workflow_run_success( @@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut started_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Pre-populate the cache with the workflow execution + workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution - # Mock get_running_executions to return an empty list - workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] + # No running node executions in cache (empty cache) # Call the method result = workflow_cycle_manager.handle_workflow_run_failed( @@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu started_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Pre-populate the cache with the workflow execution + workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution # Create a mock event event = MagicMock(spec=QueueNodeStartedEvent) @@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work started_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock the repository get method to return the real execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Pre-populate the cache with the workflow execution + workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution # Call the method result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") @@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work # Verify the result assert result == workflow_execution - # Test error case - workflow_cycle_manager._workflow_execution_repository.get.return_value = None + # Test error case - clear cache + workflow_cycle_manager._workflow_execution_cache.clear() # Expect an error when execution is not found - with pytest.raises(ValueError): + from core.app.task_pipeline.exc import WorkflowRunNotFoundError + + with pytest.raises(WorkflowRunNotFoundError): workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id") @@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): created_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock the repository to return the node execution - workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution + # Pre-populate the cache with the node execution + workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution # Call the method result = workflow_cycle_manager.handle_workflow_node_execution_success( @@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl started_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution + # Pre-populate the cache with the 
workflow execution + workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution # Call the method result = workflow_cycle_manager.handle_workflow_run_partial_success( @@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): created_at=datetime.now(UTC).replace(tzinfo=None), ) - # Mock the repository to return the node execution - workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution + # Pre-populate the cache with the node execution + workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution # Call the method result = workflow_cycle_manager.handle_workflow_node_execution_failed( diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 643efb0a0..c60800c49 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session): session_obj.merge.assert_called_once_with(modified_execution) -def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): - """Test get_by_node_execution_id method.""" - session_obj, _ = session - # Set up mock - mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") - mock_stmt = mocker.MagicMock() - mock_select.return_value = mock_stmt - mock_stmt.where.return_value = mock_stmt - - # Create a properly configured mock execution - mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel) - configure_mock_execution(mock_execution) - session_obj.scalar.return_value = mock_execution - - # Create a mock domain model to be returned by _to_domain_model - mock_domain_model = mocker.MagicMock() - # Mock the _to_domain_model method to return our mock domain model - repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) - - # Call method - result = repository.get_by_node_execution_id("test-node-execution-id") - - # Assert select was called with correct parameters - mock_select.assert_called_once() - session_obj.scalar.assert_called_once_with(mock_stmt) - # Assert _to_domain_model was called with the mock execution - repository._to_domain_model.assert_called_once_with(mock_execution) - # Assert the result is our mock domain model - assert result is mock_domain_model - - def test_get_by_workflow_run(repository, session, mocker: MockerFixture): """Test get_by_workflow_run method.""" session_obj, _ = session @@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture): assert result[0] is mock_domain_model -def test_get_running_executions(repository, session, mocker: MockerFixture): - """Test get_running_executions method.""" - session_obj, _ = session - # Set up mock - mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") - mock_stmt = mocker.MagicMock() - mock_select.return_value = mock_stmt - mock_stmt.where.return_value = mock_stmt - - # Create a properly configured mock execution - mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel) - configure_mock_execution(mock_execution) - session_obj.scalars.return_value.all.return_value = [mock_execution] - - # Create a mock domain model to be returned 
by _to_domain_model - mock_domain_model = mocker.MagicMock() - # Mock the _to_domain_model method to return our mock domain model - repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) - - # Call method - result = repository.get_running_executions("test-workflow-run-id") - - # Assert select was called with correct parameters - mock_select.assert_called_once() - session_obj.scalars.assert_called_once_with(mock_stmt) - # Assert _to_domain_model was called with the mock execution - repository._to_domain_model.assert_called_once_with(mock_execution) - # Assert the result contains our mock domain model - assert len(result) == 1 - assert result[0] is mock_domain_model - - -def test_update_via_save(repository, session): - """Test updating an existing record via save method.""" - session_obj, _ = session - # Create a mock execution - execution = MagicMock(spec=WorkflowNodeExecutionModel) - execution.tenant_id = None - execution.app_id = None - execution.inputs = None - execution.process_data = None - execution.outputs = None - execution.metadata = None - - # Mock the to_db_model method to return the execution itself - # This simulates the behavior of setting tenant_id and app_id - repository.to_db_model = MagicMock(return_value=execution) - - # Call save method to update an existing record - repository.save(execution) - - # Assert to_db_model was called with the execution - repository.to_db_model.assert_called_once_with(execution) - - # Assert session.merge was called (for updates) - session_obj.merge.assert_called_once_with(execution) - - -def test_clear(repository, session, mocker: MockerFixture): - """Test clear method.""" - session_obj, _ = session - # Set up mock - mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete") - mock_stmt = mocker.MagicMock() - mock_delete.return_value = mock_stmt - mock_stmt.where.return_value = mock_stmt - - # Mock the execute result with rowcount - mock_result = mocker.MagicMock() - mock_result.rowcount = 5 # Simulate 5 records deleted - session_obj.execute.return_value = mock_result - - # Call method - repository.clear() - - # Assert delete was called with correct parameters - mock_delete.assert_called_once_with(WorkflowNodeExecutionModel) - mock_stmt.where.assert_called() - session_obj.execute.assert_called_once_with(mock_stmt) - session_obj.commit.assert_called_once() - - def test_to_db_model(repository): """Test to_db_model method.""" # Create a domain model
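

Context on the behavioral change the updated tests pin down: lookups that previously fell back to the database (`repository.get(...)` returning `None` raised `WorkflowRunNotFoundError`) now consult only the manager's in-memory caches, so a miss raises immediately. A self-contained sketch of that cache-first contract; every name below is a hypothetical stand-in, not the real manager:

from dataclasses import dataclass, field


class RunNotFoundError(Exception):
    """Raised on a cache miss (stand-in for WorkflowRunNotFoundError)."""


@dataclass
class CacheFirstManager:
    # Per-run cache: populated on every save, read on every lookup.
    _cache: dict[str, object] = field(default_factory=dict)
    _saved: list[object] = field(default_factory=list)  # stands in for the repository

    def save_and_cache(self, execution_id: str, execution: object) -> object:
        self._saved.append(execution)          # repository write
        self._cache[execution_id] = execution  # cache for later reads
        return execution

    def get_or_raise(self, execution_id: str) -> object:
        # No repository fallback: within one run, a miss is a programming error.
        if execution_id in self._cache:
            return self._cache[execution_id]
        raise RunNotFoundError(execution_id)


manager = CacheFirstManager()
manager.save_and_cache("run-1", {"status": "running"})
assert manager.get_or_raise("run-1") == {"status": "running"}

The trade-off, visible in `_fail_running_node_executions` in the diff above, is that only node executions seen by this manager instance can be marked failed on abort; recovering running executions across processes would again require a repository-level query like the removed `get_running_executions`.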