feat: Iteration node support parallel mode (#9493)

This commit is contained in:
Novice
2024-11-05 10:32:49 +08:00
committed by GitHub
parent cca2e7876d
commit d1505b15c4
33 changed files with 1283 additions and 192 deletions

View File

@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
@@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):

View File

@@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class QueueNodeSucceededEvent(AppQueueEvent):
@@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
@@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str

View File

@@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str

View File

@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -251,6 +253,12 @@ class WorkflowCycleManage:
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
@@ -305,7 +313,9 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@@ -318,16 +328,19 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)
@@ -342,6 +355,7 @@ class WorkflowCycleManage:
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
@@ -448,6 +462,7 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)
@@ -464,7 +479,7 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@@ -608,6 +623,7 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@@ -633,7 +649,9 @@ class WorkflowCycleManage:
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,

View File

@@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
class NodeRunResult(BaseModel):

View File

@@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
"""predecessor node id"""
@@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
@@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):

View File

@@ -4,6 +4,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from flask import Flask, current_app
@@ -724,6 +725,16 @@ class GraphEngine:
"""
return time.perf_counter() - start_at > max_execution_time
def create_copy(self):
"""
create a graph engine copy
:return: with a new variable pool instance of graph engine
"""
new_instance = copy(self)
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
class GraphRunFailedError(Exception):
def __init__(self, error: str):

View File

@@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Optional
from pydantic import Field
@@ -5,6 +6,12 @@ from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class ErrorHandleMode(str, Enum):
TERMINATED = "terminated"
CONTINUE_ON_ERROR = "continue-on-error"
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
@@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
parallel_nums: int = 10 # the numbers of parallel
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
class IterationStartNodeData(BaseNodeData):

View File

@@ -1,12 +1,20 @@
import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime, timezone
from typing import Any, cast
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerSegment
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
"type": "iteration",
"config": {
"is_parallel": False,
"parallel_nums": 10,
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
},
}
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
@@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
outputs: list[Any] = []
try:
for _ in range(len(iterator_list_value)):
# run workflow
rst = graph_engine.run()
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
q,
iterator_list_value,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
index,
item,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerSegment):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Invalid index variable type: {type(index_variable)}",
)
)
return
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
event.route_node_state.node_run_result.metadata = metadata
yield event
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# append to iteration output variable list
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
if current_iteration_output_variable is None:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Iteration output variable {self.node_data.output_selector} not found",
)
# wait all threads
wait(futures)
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
)
return
current_iteration_output = current_iteration_output_variable.to_object()
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = current_index_variable.value + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output),
)
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
}
return variable_mapping
def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
event.route_node_state.node_run_result.metadata = metadata
return event
def _run_single_iter(
self,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
next_index = int(current_index) + 1
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs.insert(current_index, None)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield metadata_event
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
outputs.insert(current_index, current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
)
except Exception as e:
logger.exception(f"Iteration run failed:{str(e)}")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
def _run_single_iter_parallel(
self,
flask_app: Flask,
q: Queue,
iterator_list_value: list[str],
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration in parallel mode
"""
with flask_app.app_context():
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)