fix: loop node doesn't exit when it react the condition #24717 (#24844)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
coolfinish
2025-09-05 01:38:52 -05:00
committed by GitHub
parent 1ba69b8abf
commit cd95237ae4

View File

@@ -289,6 +289,8 @@ class LoopNode(BaseNode):
Returns: Returns:
dict: {'check_break_result': bool} dict: {'check_break_result': bool}
""" """
condition_selectors = self._extract_selectors_from_conditions(break_conditions)
extended_selectors = {**loop_variable_selectors, **condition_selectors}
# Run workflow # Run workflow
rst = graph_engine.run() rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"]) current_index_variable = variable_pool.get([self.node_id, "index"])
@@ -314,24 +316,6 @@ class LoopNode(BaseNode):
and event.node_type == NodeType.LOOP_END and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent) and not isinstance(event, NodeRunStreamChunkEvent)
): ):
# Check if variables in break conditions exist and process conditions
# Allow loop internal variables to be used in break conditions
available_conditions = []
for condition in break_conditions:
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
if variable:
available_conditions.append(condition)
# Process conditions if at least one variable is available
if available_conditions:
_, _, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=available_conditions,
operator=logical_operator,
)
if check_break_result:
break
else:
check_break_result = True check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index) yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break break
@@ -339,6 +323,23 @@ class LoopNode(BaseNode):
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index) yield self._handle_event_metadata(event=event, iter_run_index=current_index)
# Check if all variables in break conditions exist
exists_variable = False
for condition in break_conditions:
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
exists_variable = False
break
else:
exists_variable = True
if exists_variable:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if check_break_result:
break
elif isinstance(event, BaseGraphEvent): elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
# Loop run failed # Loop run failed
@@ -400,12 +401,8 @@ class LoopNode(BaseNode):
else: else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index) yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs: dict[str, Segment | int | None] = {} _outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items(): for loop_variable_key, loop_variable_selector in extended_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector) _loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment: if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment _outputs[loop_variable_key] = _loop_variable_segment
@@ -415,6 +412,10 @@ class LoopNode(BaseNode):
_outputs["loop_round"] = current_index + 1 _outputs["loop_round"] = current_index + 1
self._node_data.outputs = _outputs self._node_data.outputs = _outputs
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
if check_break_result: if check_break_result:
return {"check_break_result": True} return {"check_break_result": True}
@@ -433,6 +434,13 @@ class LoopNode(BaseNode):
return {"check_break_result": False} return {"check_break_result": False}
def _extract_selectors_from_conditions(self, conditions: list) -> dict[str, list[str]]:
return {
condition.variable_selector[1]: condition.variable_selector
for condition in conditions
if condition.variable_selector and len(condition.variable_selector) >= 2
}
def _handle_event_metadata( def _handle_event_metadata(
self, self,
*, *,