Feat/iteration single run time (#10512)

This commit is contained in:
Novice
2024-11-11 14:47:52 +08:00
committed by GitHub
parent 0c1307b083
commit f414d241c1
16 changed files with 101 additions and 29 deletions

View File

@@ -156,6 +156,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
@@ -175,6 +176,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph,
index,
item,
iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
@@ -213,6 +215,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at,
graph_engine,
iteration_graph,
iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
@@ -230,7 +233,9 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
except IterationNodeError as e:
@@ -356,15 +361,19 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None)
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
if current_index is None:
@@ -431,6 +440,8 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -439,6 +450,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
@@ -449,6 +461,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -457,6 +471,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
@@ -485,6 +500,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@@ -493,6 +510,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
duration=duration,
)
except IterationNodeError as e:
@@ -528,6 +546,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration in parallel mode
@@ -546,6 +565,7 @@ class IterationNode(BaseNode[IterationNodeData]):
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)