refactor: decouple Node and NodeData (#22581)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -1,18 +1,44 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
|
||||
|
||||
class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
class LoopEndNode(BaseNode):
|
||||
"""
|
||||
Loop End Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopEndNodeData
|
||||
_node_type = NodeType.LOOP_END
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = LoopEndNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
@@ -3,7 +3,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import (
|
||||
@@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
@@ -43,14 +44,36 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(BaseNode[LoopNodeData]):
|
||||
class LoopNode(BaseNode):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
_node_data: LoopNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self.node_data.loop_count
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
loop_count = self._node_data.loop_count
|
||||
break_conditions = self._node_data.break_conditions
|
||||
logical_operator = self._node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
if not self._node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
|
||||
|
||||
# Initialize graph
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
@@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
|
||||
# Initialize loop variables
|
||||
loop_variable_selectors = {}
|
||||
if self.node_data.loop_variables:
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
if self._node_data.loop_variables:
|
||||
for loop_variable in self._node_data.loop_variables:
|
||||
value_processor = {
|
||||
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var=loop_variable: variable_pool.get(var.value),
|
||||
@@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunStartedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={"loop_length": loop_count},
|
||||
@@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunSucceededEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self.node_data.outputs,
|
||||
outputs=self._node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
@@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self.node_data.outputs,
|
||||
outputs=self._node_data.outputs,
|
||||
inputs=inputs,
|
||||
)
|
||||
)
|
||||
@@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
@@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
@@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
@@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
_outputs["loop_round"] = current_index + 1
|
||||
self.node_data.outputs = _outputs
|
||||
self._node_data.outputs = _outputs
|
||||
|
||||
if check_break_result:
|
||||
return {"check_break_result": True}
|
||||
@@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
loop_node_type=self.type_,
|
||||
loop_node_data=self._node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=self.node_data.outputs,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
)
|
||||
|
||||
return {"check_break_result": False}
|
||||
@@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# init graph
|
||||
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
|
||||
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
@@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
|
@@ -1,18 +1,44 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
|
||||
|
||||
class LoopStartNode(BaseNode[LoopStartNodeData]):
|
||||
class LoopStartNode(BaseNode):
|
||||
"""
|
||||
Loop Start Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopStartNodeData
|
||||
_node_type = NodeType.LOOP_START
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = LoopStartNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
Reference in New Issue
Block a user