refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@@ -21,10 +22,32 @@ from .exc import (
)
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
class CodeNode(BaseNode):
_node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
code = self.node_data.code
code_language = self._node_data.code_language
code = self._node_data.code
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled