refactor: decouple Node and NodeData (#22581)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
|
@@ -12,7 +12,7 @@ from typing import Any, Optional, cast
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
@@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.utils import variable_utils
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.enums import UserFrom
|
||||
@@ -260,12 +258,16 @@ class GraphEngine:
|
||||
# convert to specific node
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
|
||||
# Import here to avoid circular import
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||
|
||||
# init workflow run state
|
||||
node_instance = node_cls( # type: ignore
|
||||
node = node_cls(
|
||||
id=route_node_state.id,
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
@@ -274,11 +276,11 @@ class GraphEngine:
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||
node.init_node_data(node_config.get("data", {}))
|
||||
try:
|
||||
# run node
|
||||
generator = self._run_node(
|
||||
node_instance=node_instance,
|
||||
node=node,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
@@ -306,16 +308,16 @@ class GraphEngine:
|
||||
route_node_state.failed_reason = str(e)
|
||||
yield NodeRunFailedEvent(
|
||||
error=str(e),
|
||||
id=node_instance.id,
|
||||
id=node.id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
raise e
|
||||
|
||||
@@ -337,7 +339,7 @@ class GraphEngine:
|
||||
edge = edge_mappings[0]
|
||||
if (
|
||||
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and edge.run_condition is None
|
||||
):
|
||||
break
|
||||
@@ -413,8 +415,8 @@ class GraphEngine:
|
||||
|
||||
next_node_id = final_node_id
|
||||
elif (
|
||||
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and node_instance.should_continue_on_error
|
||||
node.continue_on_error
|
||||
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||
):
|
||||
break
|
||||
@@ -597,7 +599,7 @@ class GraphEngine:
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
node: BaseNode,
|
||||
route_node_state: RouteNodeState,
|
||||
parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
@@ -611,29 +613,29 @@ class GraphEngine:
|
||||
# trigger node run start event
|
||||
agent_strategy = (
|
||||
AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
|
||||
icon=cast(AgentNode, node_instance).agent_strategy_icon,
|
||||
name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
|
||||
icon=cast(AgentNode, node).agent_strategy_icon,
|
||||
)
|
||||
if node_instance.node_type == NodeType.AGENT
|
||||
if node.type_ == NodeType.AGENT
|
||||
else None
|
||||
)
|
||||
yield NodeRunStartedEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
predecessor_node_id=node_instance.previous_node_id,
|
||||
predecessor_node_id=node.previous_node_id,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
agent_strategy=agent_strategy,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
|
||||
max_retries = node_instance.node_data.retry_config.max_retries
|
||||
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||
max_retries = node.retry_config.max_retries
|
||||
retry_interval = node.retry_config.retry_interval_seconds
|
||||
retries = 0
|
||||
should_continue_retry = True
|
||||
while should_continue_retry and retries <= max_retries:
|
||||
@@ -642,7 +644,7 @@ class GraphEngine:
|
||||
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
# yield control to other threads
|
||||
time.sleep(0.001)
|
||||
event_stream = node_instance.run()
|
||||
event_stream = node.run()
|
||||
for event in event_stream:
|
||||
if isinstance(event, GraphEngineEvent):
|
||||
# add parallel info to iteration event
|
||||
@@ -658,21 +660,21 @@ class GraphEngine:
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if (
|
||||
retries == max_retries
|
||||
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||
and node.type_ == NodeType.HTTP_REQUEST
|
||||
and run_result.outputs
|
||||
and not node_instance.should_continue_on_error
|
||||
and not node.continue_on_error
|
||||
):
|
||||
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if node_instance.should_retry and retries < max_retries:
|
||||
if node.retry and retries < max_retries:
|
||||
retries += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=str(uuid.uuid4()),
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
predecessor_node_id=node_instance.previous_node_id,
|
||||
predecessor_node_id=node.previous_node_id,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
@@ -680,17 +682,17 @@ class GraphEngine:
|
||||
error=run_result.error or "Unknown error",
|
||||
retry_index=retries,
|
||||
start_at=retry_start_at,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
break
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
if node_instance.should_continue_on_error:
|
||||
if node.continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
node,
|
||||
event.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
@@ -701,44 +703,44 @@ class GraphEngine:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if (
|
||||
node_instance.should_continue_on_error
|
||||
and self.graph.edge_mapping.get(node_instance.node_id)
|
||||
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
|
||||
node.continue_on_error
|
||||
and self.graph.edge_mapping.get(node.node_id)
|
||||
and node.error_strategy is ErrorStrategy.FAIL_BRANCH
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(
|
||||
@@ -758,7 +760,7 @@ class GraphEngine:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
@@ -783,26 +785,26 @@ class GraphEngine:
|
||||
run_result.metadata = metadata_dict
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
should_continue_retry = False
|
||||
|
||||
break
|
||||
elif isinstance(event, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
chunk_content=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
@@ -810,14 +812,14 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
elif isinstance(event, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
route_node_state=route_node_state,
|
||||
@@ -825,7 +827,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
@@ -833,20 +835,20 @@ class GraphEngine:
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
id=node.id,
|
||||
node_id=node.node_id,
|
||||
node_type=node.type_,
|
||||
node_data=node.get_base_node_data(),
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
node_version=node_instance.version(),
|
||||
node_version=node.version(),
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||
logger.exception(f"Node {node.title} run failed")
|
||||
raise e
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
@@ -886,22 +888,14 @@ class GraphEngine:
|
||||
|
||||
def _handle_continue_on_error(
|
||||
self,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
node: BaseNode,
|
||||
error_result: NodeRunResult,
|
||||
variable_pool: VariablePool,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> NodeRunResult:
|
||||
"""
|
||||
handle continue on error when self._should_continue_on_error is True
|
||||
|
||||
|
||||
:param error_result (NodeRunResult): error run result
|
||||
:param variable_pool (VariablePool): variable pool
|
||||
:return: excption run result
|
||||
"""
|
||||
# add error message and error type to variable pool
|
||||
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
|
||||
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
|
||||
variable_pool.add([node.node_id, "error_message"], error_result.error)
|
||||
variable_pool.add([node.node_id, "error_type"], error_result.error_type)
|
||||
# add error message to handle_exceptions
|
||||
handle_exceptions.append(error_result.error or "")
|
||||
node_error_args: dict[str, Any] = {
|
||||
@@ -909,21 +903,21 @@ class GraphEngine:
|
||||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
"metadata": {
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
|
||||
},
|
||||
}
|
||||
|
||||
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||
return NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
**node_instance.node_data.default_value_dict,
|
||||
**node.default_value_dict,
|
||||
"error_message": error_result.error,
|
||||
"error_type": error_result.error_type,
|
||||
},
|
||||
)
|
||||
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
|
||||
if self.graph.edge_mapping.get(node_instance.node_id):
|
||||
elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
|
||||
if self.graph.edge_mapping.get(node.node_id):
|
||||
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
|
||||
return NodeRunResult(
|
||||
**node_error_args,
|
||||
|
Reference in New Issue
Block a user