refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Author: -LAN-
Date: 2025-07-18 10:08:51 +08:00 (committed by GitHub)
Commit: 460a825ef1 · Parent: 54c56f2d05
65 changed files with 2305 additions and 1146 deletions
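
The refactor removes the `BaseNode[BaseNodeData]` generic and stops the engine from reaching into `node_instance.node_data`. Nodes are constructed first and then handed their raw `data` mapping via `init_node_data()`, and the engine reads everything it needs through accessors on the node itself: `get_base_node_data()`, `type_`, `title`, `version()`, `error_strategy`, `retry_config`, `continue_on_error`, `retry`, and `default_value_dict`. A minimal sketch of that surface, using simplified stand-in classes rather than the actual Dify models:

```python
# Simplified stand-ins, not the actual Dify classes; they only illustrate the
# accessor surface the engine relies on after this refactor.
from collections.abc import Mapping
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional


class ErrorStrategy(Enum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"


@dataclass
class RetryConfig:
    max_retries: int = 0
    retry_interval_seconds: float = 0.0


@dataclass
class BaseNodeData:
    title: str = ""
    error_strategy: Optional[ErrorStrategy] = None
    retry_config: RetryConfig = field(default_factory=RetryConfig)
    default_value_dict: dict[str, Any] = field(default_factory=dict)


class BaseNode:
    def __init__(self, id: str, config: Mapping[str, Any]) -> None:
        self.id = id
        self.node_id = str(config.get("id", id))
        self._node_data: Optional[BaseNodeData] = None

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Each concrete node type validates "data" into its own typed model;
        # this sketch only keeps the base fields.
        self._node_data = BaseNodeData(title=str(data.get("title", "")))

    def get_base_node_data(self) -> BaseNodeData:
        assert self._node_data is not None, "init_node_data() must be called first"
        return self._node_data

    # Properties so callers never touch the underlying NodeData object.
    @property
    def title(self) -> str:
        return self.get_base_node_data().title

    @property
    def error_strategy(self) -> Optional[ErrorStrategy]:
        return self.get_base_node_data().error_strategy

    @property
    def retry_config(self) -> RetryConfig:
        return self.get_base_node_data().retry_config

    @property
    def continue_on_error(self) -> bool:
        # Illustrative rule only; the real check may apply more conditions.
        return self.error_strategy is not None
```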


@@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]


@@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
@@ -260,12 +258,16 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
# Import here to avoid circular import
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
node = node_cls(
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
@@ -274,11 +276,11 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node.init_node_data(node_config.get("data", {}))
try:
# run node
generator = self._run_node(
node_instance=node_instance,
node=node,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
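
Construction is now two steps: instantiate the node class, then pass it the raw `data` dict through `init_node_data()`, which replaces the old `cast(BaseNode[BaseNodeData], node_instance)`. A sketch of how a concrete node might parse that dict into its typed model (hypothetical, trimmed-down `HttpRequestNodeData`; the real node data models are richer):

```python
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel


class HttpRequestNodeData(BaseModel):
    # Hypothetical, trimmed-down model for illustration only.
    title: str = ""
    url: str = ""
    method: str = "GET"


class HttpRequestNode:
    _node_data: HttpRequestNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        # Validate the raw "data" mapping from the graph config into a typed model.
        self._node_data = HttpRequestNodeData.model_validate(dict(data))

    def get_base_node_data(self) -> HttpRequestNodeData:
        return self._node_data


node = HttpRequestNode()
node.init_node_data({"title": "Fetch", "url": "https://example.com"})
print(node.get_base_node_data().method)  # -> "GET"
```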
@@ -306,16 +308,16 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
id=node_instance.id,
id=node.id,
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
raise e
@@ -337,7 +339,7 @@ class GraphEngine:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
@@ -413,8 +415,8 @@ class GraphEngine:
next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
node.continue_on_error
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
@@ -597,7 +599,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode[BaseNodeData],
node: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@@ -611,29 +613,29 @@ class GraphEngine:
# trigger node run start event
agent_strategy = (
AgentNodeStrategyInit(
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
icon=cast(AgentNode, node_instance).agent_strategy_icon,
name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
icon=cast(AgentNode, node).agent_strategy_icon,
)
if node_instance.node_type == NodeType.AGENT
if node.type_ == NodeType.AGENT
else None
)
yield NodeRunStartedEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
node_version=node.version(),
)
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
max_retries = node.retry_config.max_retries
retry_interval = node.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
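
Every event yielded below carries the same five node-derived fields, now read from the node rather than from `node_instance.node_data`. A hypothetical helper (not part of the diff) makes the repeated pattern explicit:

```python
from typing import Any


def common_event_fields(node: Any) -> dict[str, Any]:
    # Fields the engine attaches to every node event after the refactor.
    return {
        "id": node.id,
        "node_id": node.node_id,
        "node_type": node.type_,
        "node_data": node.get_base_node_data(),
        "node_version": node.version(),
    }
```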
@@ -642,7 +644,7 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads
time.sleep(0.001)
event_stream = node_instance.run()
event_stream = node.run()
for event in event_stream:
if isinstance(event, GraphEngineEvent):
# add parallel info to iteration event
@@ -658,21 +660,21 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and node.type_ == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
and not node.continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
if node.retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=str(uuid.uuid4()),
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
@@ -680,17 +682,17 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
node_version=node.version(),
)
time.sleep(retry_interval)
break
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
if node.continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
node,
event.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
@@ -701,44 +703,44 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if (
node_instance.should_continue_on_error
and self.graph.edge_mapping.get(node_instance.node_id)
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
node.continue_on_error
and self.graph.edge_mapping.get(node.node_id)
and node.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(
@@ -758,7 +760,7 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
@@ -783,26 +785,26 @@ class GraphEngine:
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
break
elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
chunk_content=event.chunk_content,
from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state,
@@ -810,14 +812,14 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
retriever_resources=event.retriever_resources,
context=event.context,
route_node_state=route_node_state,
@@ -825,7 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@@ -833,20 +835,20 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
logger.exception(f"Node {node.title} run failed")
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
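
`_append_variables_recursively` itself is unchanged here; as the call sites above suggest, it writes each output into the variable pool and recurses into dict values. A minimal sketch of that assumed behavior, with a plain dict standing in for the real `VariablePool`:

```python
from typing import Any


def append_variables_recursively(
    pool: dict[tuple[str, ...], Any], node_id: str, key_list: list[str], value: Any
) -> None:
    # Store the value under (node_id, *keys), then descend into nested dicts so
    # nested keys remain addressable by selector.
    pool[(node_id, *key_list)] = value
    if isinstance(value, dict):
        for key, nested in value.items():
            append_variables_recursively(pool, node_id, [*key_list, key], nested)


pool: dict[tuple[str, ...], Any] = {}
append_variables_recursively(pool, "node_1", ["result"], {"text": "ok"})
# pool now holds ("node_1", "result") and ("node_1", "result", "text")
```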
@@ -886,22 +888,14 @@ class GraphEngine:
def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
node: BaseNode,
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
variable_pool.add([node.node_id, "error_message"], error_result.error)
variable_pool.add([node.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
@@ -909,21 +903,21 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
**node.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
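
`_handle_continue_on_error` now reads `error_strategy` and `default_value_dict` straight off the node. A condensed sketch of the two strategies (simplified types, not the actual `NodeRunResult`/`VariablePool` signatures):

```python
from enum import Enum
from typing import Any


class ErrorStrategy(Enum):
    DEFAULT_VALUE = "default-value"
    FAIL_BRANCH = "fail-branch"


def handle_continue_on_error(
    strategy: ErrorStrategy,
    default_value_dict: dict[str, Any],
    error: str,
    error_type: str,
    has_outgoing_edges: bool,
) -> dict[str, Any]:
    if strategy is ErrorStrategy.DEFAULT_VALUE:
        # Fall back to the node's configured defaults, plus the error details.
        return {
            "outputs": {**default_value_dict, "error_message": error, "error_type": error_type},
            "edge_source_handle": None,
        }
    # FAIL_BRANCH: route through the failure edge when one exists; the error
    # details stay on the run result itself.
    return {
        "outputs": {},
        "edge_source_handle": "fail-branch" if has_outgoing_edges else None,
    }
```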