refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode[LoopEndNodeData]):
class LoopEndNode(BaseNode):
"""
Loop End Node.
"""
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
@@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
@@ -43,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(BaseNode[LoopNodeData]):
class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if self._node_data.loop_variables:
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
inputs=inputs,
)
)
@@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
@@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self.node_data.outputs,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
@@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LoopNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []:
for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

View File

@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode[LoopStartNodeData]):
class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"