Feat: continue on error (#11458)
Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local> Co-authored-by: Novice Lee <novicelee@NoviPro.local>
This commit is contained in:
@@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
@@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color="green")
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
|
@@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
@@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
error_type: Optional[str] = None # error type if status is failed
|
||||
|
@@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: Optional[int] = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
###########################################
|
||||
@@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
@@ -64,13 +64,21 @@ class Graph(BaseModel):
|
||||
edge_configs = graph_config.get("edges")
|
||||
if edge_configs is None:
|
||||
edge_configs = []
|
||||
# node configs
|
||||
node_configs = graph_config.get("nodes")
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# reorganize edges mapping
|
||||
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
target_edge_ids = set()
|
||||
fail_branch_source_node_id = [
|
||||
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
|
||||
]
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get("source")
|
||||
if not source_node_id:
|
||||
@@ -90,8 +98,16 @@ class Graph(BaseModel):
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
|
||||
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
|
||||
if edge_config.get("sourceHandle"):
|
||||
if (
|
||||
edge_config.get("source") in fail_branch_source_node_id
|
||||
and edge_config.get("sourceHandle") != "fail-branch"
|
||||
):
|
||||
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
|
||||
elif edge_config.get("sourceHandle") != "source":
|
||||
run_condition = RunCondition(
|
||||
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
|
||||
)
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
|
||||
@@ -100,13 +116,6 @@ class Graph(BaseModel):
|
||||
edge_mapping[source_node_id].append(graph_edge)
|
||||
reverse_edge_mapping[target_node_id].append(graph_edge)
|
||||
|
||||
# node configs
|
||||
node_configs = graph_config.get("nodes")
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
|
@@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
EXCEPTION = "exception"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
@@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):
|
||||
|
||||
:param run_result: run result
|
||||
"""
|
||||
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
|
||||
if self.status in {
|
||||
RouteNodeState.Status.SUCCESS,
|
||||
RouteNodeState.Status.FAILED,
|
||||
RouteNodeState.Status.EXCEPTION,
|
||||
}:
|
||||
raise Exception(f"Route state {self.id} already finished")
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
@@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
self.status = RouteNodeState.Status.FAILED
|
||||
self.failed_reason = run_result.error
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
self.status = RouteNodeState.Status.EXCEPTION
|
||||
self.failed_reason = run_result.error
|
||||
else:
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
|
@@ -5,21 +5,23 @@ import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseIterationEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
@@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
@@ -128,6 +132,7 @@ class GraphEngine:
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
handle_exceptions = []
|
||||
|
||||
try:
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
@@ -140,13 +145,17 @@ class GraphEngine:
|
||||
)
|
||||
|
||||
# run graph
|
||||
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
||||
|
||||
generator = stream_processor.process(
|
||||
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
|
||||
)
|
||||
for item in generator:
|
||||
try:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
|
||||
yield GraphRunFailedEvent(
|
||||
error=item.route_node_state.failed_reason or "Unknown error.",
|
||||
exceptions_count=len(handle_exceptions),
|
||||
)
|
||||
return
|
||||
elif isinstance(item, NodeRunSucceededEvent):
|
||||
if item.node_type == NodeType.END:
|
||||
@@ -172,19 +181,24 @@ class GraphEngine:
|
||||
].strip()
|
||||
except Exception as e:
|
||||
logger.exception("Graph run failed")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
||||
# count exceptions to determine partial success
|
||||
if len(handle_exceptions) > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
|
||||
)
|
||||
else:
|
||||
# trigger graph run success event
|
||||
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
||||
self._release_thread()
|
||||
except GraphRunFailedError as e:
|
||||
yield GraphRunFailedEvent(error=e.error)
|
||||
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
|
||||
self._release_thread()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when graph running")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
||||
self._release_thread()
|
||||
raise e
|
||||
|
||||
@@ -198,6 +212,7 @@ class GraphEngine:
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
parallel_start_node_id = None
|
||||
if in_parallel_id:
|
||||
@@ -242,7 +257,7 @@ class GraphEngine:
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||
try:
|
||||
# run node
|
||||
generator = self._run_node(
|
||||
@@ -252,6 +267,7 @@ class GraphEngine:
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
@@ -301,7 +317,12 @@ class GraphEngine:
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
edge = edge_mappings[0]
|
||||
|
||||
if (
|
||||
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and edge.run_condition is None
|
||||
):
|
||||
break
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@@ -334,7 +355,7 @@ class GraphEngine:
|
||||
if len(sub_edge_mappings) == 0:
|
||||
continue
|
||||
|
||||
edge = sub_edge_mappings[0]
|
||||
edge = cast(GraphEdge, sub_edge_mappings[0])
|
||||
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@@ -355,6 +376,7 @@ class GraphEngine:
|
||||
edge_mappings=sub_edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
@@ -369,11 +391,18 @@ class GraphEngine:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
elif (
|
||||
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||
and node_instance.should_continue_on_error
|
||||
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||
):
|
||||
break
|
||||
else:
|
||||
parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
@@ -395,6 +424,7 @@ class GraphEngine:
|
||||
edge_mappings: list[GraphEdge],
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> Generator[GraphEngineEvent | str, None, None]:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||
@@ -438,6 +468,7 @@ class GraphEngine:
|
||||
"parallel_start_node_id": edge.target_node_id,
|
||||
"parent_parallel_id": in_parallel_id,
|
||||
"parent_parallel_start_node_id": parallel_start_node_id,
|
||||
"handle_exceptions": handle_exceptions,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -481,6 +512,7 @@ class GraphEngine:
|
||||
parallel_start_node_id: str,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> None:
|
||||
"""
|
||||
Run parallel nodes
|
||||
@@ -502,6 +534,7 @@ class GraphEngine:
|
||||
in_parallel_id=parallel_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
@@ -548,6 +581,7 @@ class GraphEngine:
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
@@ -587,19 +621,55 @@ class GraphEngine:
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
if node_instance.should_continue_on_error:
|
||||
# if run failed, handle error
|
||||
run_result = self._handle_continue_on_error(
|
||||
node_instance,
|
||||
item.run_result,
|
||||
self.graph_runtime_state.variable_pool,
|
||||
handle_exceptions=handle_exceptions,
|
||||
)
|
||||
route_node_state.node_run_result = run_result
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
else:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or "Unknown error.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||
node_instance.node_id
|
||||
):
|
||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
@@ -735,6 +805,56 @@ class GraphEngine:
|
||||
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
|
||||
return new_instance
|
||||
|
||||
def _handle_continue_on_error(
|
||||
self,
|
||||
node_instance: BaseNode[BaseNodeData],
|
||||
error_result: NodeRunResult,
|
||||
variable_pool: VariablePool,
|
||||
handle_exceptions: list[str] = [],
|
||||
) -> NodeRunResult:
|
||||
"""
|
||||
handle continue on error when self._should_continue_on_error is True
|
||||
|
||||
|
||||
:param error_result (NodeRunResult): error run result
|
||||
:param variable_pool (VariablePool): variable pool
|
||||
:return: excption run result
|
||||
"""
|
||||
# add error message and error type to variable pool
|
||||
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
|
||||
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
|
||||
# add error message to handle_exceptions
|
||||
handle_exceptions.append(error_result.error)
|
||||
node_error_args = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": error_result.error,
|
||||
"inputs": error_result.inputs,
|
||||
"metadata": {
|
||||
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
||||
},
|
||||
}
|
||||
|
||||
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||
return NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
**node_instance.node_data.default_value_dict,
|
||||
"error_message": error_result.error,
|
||||
"error_type": error_result.error_type,
|
||||
},
|
||||
)
|
||||
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
|
||||
if self.graph.edge_mapping.get(node_instance.node_id):
|
||||
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
|
||||
return NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
"error_message": error_result.error,
|
||||
"error_type": error_result.error_type,
|
||||
},
|
||||
)
|
||||
return error_result
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
|
@@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
@@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in {
|
||||
NodeType.ANSWER,
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}:
|
||||
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
|
||||
if (
|
||||
source_node_type
|
||||
in {
|
||||
NodeType.ANSWER,
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}
|
||||
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
|
||||
):
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
|
@@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
|
@@ -1,14 +1,124 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.workflow.nodes.base.exc import DefaultValueTypeError
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
NumberType = Union[int, float]
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str) -> Any:
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
if self.type is None:
|
||||
raise DefaultValueTypeError("type field is required")
|
||||
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator = type_validators.get(self.type)
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
error_strategy: Optional[ErrorStrategy] = None
|
||||
default_value: Optional[list[DefaultValue]] = None
|
||||
version: str = "1"
|
||||
|
||||
@property
|
||||
def default_value_dict(self):
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
|
10
api/core/workflow/nodes/base/exc.py
Normal file
10
api/core/workflow/nodes/base/exc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class BaseNodeError(Exception):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultValueTypeError(BaseNodeError):
|
||||
"""Raised when the default value type is invalid."""
|
||||
|
||||
pass
|
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
@@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
result = self._run()
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run")
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
|
||||
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(run_result=result)
|
||||
@@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
:return:
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
@property
|
||||
def should_continue_on_error(self) -> bool:
|
||||
"""judge if should continue on error
|
||||
|
||||
Returns:
|
||||
bool: if should continue on error
|
||||
"""
|
||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||
|
@@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
|
@@ -22,3 +22,16 @@ class NodeType(StrEnum):
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
FAIL_BRANCH = "fail-branch"
|
||||
DEFAULT_VALUE = "default-value"
|
||||
|
||||
|
||||
class FailBranchSourceHandle(StrEnum):
|
||||
FAILED = "fail-branch"
|
||||
SUCCESS = "success-branch"
|
||||
|
||||
|
||||
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||
|
@@ -21,6 +21,7 @@ from .entities import (
|
||||
from .exc import (
|
||||
AuthorizationConfigError,
|
||||
FileFetchError,
|
||||
HttpRequestNodeError,
|
||||
InvalidHttpMethodError,
|
||||
ResponseSizeError,
|
||||
)
|
||||
@@ -208,8 +209,10 @@ class Executor:
|
||||
"follow_redirects": True,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
|
@@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
|
||||
response = http_executor.invoke()
|
||||
files = self.extract_files(url=http_executor.url, response=response)
|
||||
if not response.response.is_success and self.should_continue_on_error:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.text if not files else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
process_data={
|
||||
"request": http_executor.to_log(),
|
||||
},
|
||||
error=f"Request failed with status code {response.status_code}",
|
||||
error_type="HTTPResponseCodeError",
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
@@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
process_data=process_data,
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
@@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode):
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
outputs = {"class_name": category_name}
|
||||
outputs = {"class_name": category_name, "class_id": category_id}
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
@@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
# get parameters
|
||||
@@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
# convert tool messages
|
||||
|
Reference in New Issue
Block a user