chore(api/core): apply ruff reformatting (#7624)
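This commit is a formatting-only pass over api/core produced by ruff's formatter; no behavior change is intended. As a rough, hypothetical sketch of the pattern applied throughout (the `describe_edge` helper below is made up for illustration and is not code from this repository): calls that fit within the configured line length are collapsed onto a single line, single-quoted strings become double-quoted, and argument lists that stay wrapped gain trailing commas.

```python
# Illustration only: a stand-in helper, not part of api/core.
def describe_edge(*, source_node_id: str, target_node_id: str, run_condition=None) -> str:
    return f"{source_node_id} -> {target_node_id} ({run_condition})"


# Before: hand-wrapped call, single quotes.
# edge = describe_edge(
#     source_node_id='start',
#     target_node_id='llm',
#     run_condition=None
# )

# After formatting: one line (it fits), double quotes.
edge = describe_edge(source_node_id="start", target_node_id="llm", run_condition=None)
print(edge)
```

Calls that still do not fit on one line keep their explicit wrapping, as the hunks below show.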
@@ -8,19 +8,13 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
 class RunConditionHandler(ABC):
-    def __init__(self,
-                 init_params: GraphInitParams,
-                 graph: Graph,
-                 condition: RunCondition):
+    def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
         self.init_params = init_params
         self.graph = graph
         self.condition = condition

     @abstractmethod
-    def check(self,
-              graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState
-              ) -> bool:
+    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
         """
         Check if the condition can be executed
@@ -4,10 +4,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
 class BranchIdentifyRunConditionHandler(RunConditionHandler):
-    def check(self,
-              graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState) -> bool:
+    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
         """
         Check if the condition can be executed
@@ -5,10 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
 class ConditionRunConditionHandlerHandler(RunConditionHandler):
-    def check(self,
-              graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState
-              ) -> bool:
+    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
         """
         Check if the condition can be executed
@@ -22,8 +19,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
         # process condition
         condition_processor = ConditionProcessor()
         input_conditions, group_result = condition_processor.process_conditions(
-            variable_pool=graph_runtime_state.variable_pool,
-            conditions=self.condition.conditions
+            variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions
         )

         # Apply the logical operator for the current case
@@ -9,9 +9,7 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
 class ConditionManager:
     @staticmethod
     def get_condition_handler(
-        init_params: GraphInitParams,
-        graph: Graph,
-        run_condition: RunCondition
+        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
     ) -> RunConditionHandler:
         """
         Get condition handler
@@ -22,14 +20,6 @@ class ConditionManager:
         :return: condition handler
         """
         if run_condition.type == "branch_identify":
-            return BranchIdentifyRunConditionHandler(
-                init_params=init_params,
-                graph=graph,
-                condition=run_condition
-            )
+            return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
         else:
-            return ConditionRunConditionHandlerHandler(
-                init_params=init_params,
-                graph=graph,
-                condition=run_condition
-            )
+            return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)
@@ -34,38 +34,25 @@ class Graph(BaseModel):
     root_node_id: str = Field(..., description="root node id of the graph")
     node_ids: list[str] = Field(default_factory=list, description="graph node ids")
     node_id_config_mapping: dict[str, dict] = Field(
-        default_factory=list,
-        description="node configs mapping (node id: node config)"
+        default_factory=list, description="node configs mapping (node id: node config)"
     )
     edge_mapping: dict[str, list[GraphEdge]] = Field(
-        default_factory=dict,
-        description="graph edge mapping (source node id: edges)"
+        default_factory=dict, description="graph edge mapping (source node id: edges)"
     )
     reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
-        default_factory=dict,
-        description="reverse graph edge mapping (target node id: edges)"
+        default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
     )
     parallel_mapping: dict[str, GraphParallel] = Field(
-        default_factory=dict,
-        description="graph parallel mapping (parallel id: parallel)"
+        default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
     )
     node_parallel_mapping: dict[str, str] = Field(
-        default_factory=dict,
-        description="graph node parallel mapping (node id: parallel id)"
-    )
-    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
-        ...,
-        description="answer stream generate routes"
-    )
-    end_stream_param: EndStreamParam = Field(
-        ...,
-        description="end stream param"
+        default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
     )
+    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
+    end_stream_param: EndStreamParam = Field(..., description="end stream param")

     @classmethod
-    def init(cls,
-             graph_config: Mapping[str, Any],
-             root_node_id: Optional[str] = None) -> "Graph":
+    def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
         """
         Init graph
@@ -74,7 +61,7 @@ class Graph(BaseModel):
         :return: graph
         """
         # edge configs
-        edge_configs = graph_config.get('edges')
+        edge_configs = graph_config.get("edges")
         if edge_configs is None:
             edge_configs = []
@@ -85,14 +72,14 @@ class Graph(BaseModel):
         reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
         target_edge_ids = set()
         for edge_config in edge_configs:
-            source_node_id = edge_config.get('source')
+            source_node_id = edge_config.get("source")
             if not source_node_id:
                 continue

             if source_node_id not in edge_mapping:
                 edge_mapping[source_node_id] = []

-            target_node_id = edge_config.get('target')
+            target_node_id = edge_config.get("target")
             if not target_node_id:
                 continue
@@ -107,23 +94,18 @@ class Graph(BaseModel):
             # parse run condition
             run_condition = None
-            if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
-                run_condition = RunCondition(
-                    type='branch_identify',
-                    branch_identify=edge_config.get('sourceHandle')
-                )
+            if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
+                run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))

             graph_edge = GraphEdge(
-                source_node_id=source_node_id,
-                target_node_id=target_node_id,
-                run_condition=run_condition
+                source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
             )

             edge_mapping[source_node_id].append(graph_edge)
             reverse_edge_mapping[target_node_id].append(graph_edge)

         # node configs
-        node_configs = graph_config.get('nodes')
+        node_configs = graph_config.get("nodes")
         if not node_configs:
             raise ValueError("Graph must have at least one node")
@@ -133,7 +115,7 @@ class Graph(BaseModel):
         root_node_configs = []
         all_node_id_config_mapping: dict[str, dict] = {}
         for node_config in node_configs:
-            node_id = node_config.get('id')
+            node_id = node_config.get("id")
             if not node_id:
                 continue
@@ -142,30 +124,29 @@ class Graph(BaseModel):
             all_node_id_config_mapping[node_id] = node_config

-        root_node_ids = [node_config.get('id') for node_config in root_node_configs]
+        root_node_ids = [node_config.get("id") for node_config in root_node_configs]

         # fetch root node
         if not root_node_id:
             # if no root node id, use the START type node as root node
-            root_node_id = next((node_config.get("id") for node_config in root_node_configs
-                                 if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
+            root_node_id = next(
+                (
+                    node_config.get("id")
+                    for node_config in root_node_configs
+                    if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                ),
+                None,
+            )

         if not root_node_id or root_node_id not in root_node_ids:
             raise ValueError(f"Root node id {root_node_id} not found in the graph")

         # Check whether it is connected to the previous node
-        cls._check_connected_to_previous_node(
-            route=[root_node_id],
-            edge_mapping=edge_mapping
-        )
+        cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)

         # fetch all node ids from root node
         node_ids = [root_node_id]
-        cls._recursively_add_node_ids(
-            node_ids=node_ids,
-            edge_mapping=edge_mapping,
-            node_id=root_node_id
-        )
+        cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)

         node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
@@ -177,29 +158,26 @@ class Graph(BaseModel):
             reverse_edge_mapping=reverse_edge_mapping,
             start_node_id=root_node_id,
             parallel_mapping=parallel_mapping,
-            node_parallel_mapping=node_parallel_mapping
+            node_parallel_mapping=node_parallel_mapping,
         )

         # Check if it exceeds N layers of parallel
         for parallel in parallel_mapping.values():
             if parallel.parent_parallel_id:
                 cls._check_exceed_parallel_limit(
-                    parallel_mapping=parallel_mapping,
-                    level_limit=3,
-                    parent_parallel_id=parallel.parent_parallel_id
+                    parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
                 )

         # init answer stream generate routes
         answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
-            node_id_config_mapping=node_id_config_mapping,
-            reverse_edge_mapping=reverse_edge_mapping
+            node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
         )

         # init end stream param
         end_stream_param = EndStreamGeneratorRouter.init(
             node_id_config_mapping=node_id_config_mapping,
             reverse_edge_mapping=reverse_edge_mapping,
-            node_parallel_mapping=node_parallel_mapping
+            node_parallel_mapping=node_parallel_mapping,
         )

         # init graph
@@ -212,14 +190,14 @@ class Graph(BaseModel):
             parallel_mapping=parallel_mapping,
             node_parallel_mapping=node_parallel_mapping,
             answer_stream_generate_routes=answer_stream_generate_routes,
-            end_stream_param=end_stream_param
+            end_stream_param=end_stream_param,
         )

         return graph

-    def add_extra_edge(self, source_node_id: str,
-                       target_node_id: str,
-                       run_condition: Optional[RunCondition] = None) -> None:
+    def add_extra_edge(
+        self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
+    ) -> None:
         """
         Add extra edge to the graph
@@ -237,9 +215,7 @@ class Graph(BaseModel):
             return

         graph_edge = GraphEdge(
-            source_node_id=source_node_id,
-            target_node_id=target_node_id,
-            run_condition=run_condition
+            source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
         )

         self.edge_mapping[source_node_id].append(graph_edge)
@@ -254,17 +230,18 @@ class Graph(BaseModel):
         for node_id in self.node_ids:
             if node_id not in self.edge_mapping:
                 leaf_node_ids.append(node_id)
-            elif (len(self.edge_mapping[node_id]) == 1
-                  and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
+            elif (
+                len(self.edge_mapping[node_id]) == 1
+                and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
+            ):
                 leaf_node_ids.append(node_id)

         return leaf_node_ids

     @classmethod
-    def _recursively_add_node_ids(cls,
-                                  node_ids: list[str],
-                                  edge_mapping: dict[str, list[GraphEdge]],
-                                  node_id: str) -> None:
+    def _recursively_add_node_ids(
+        cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
+    ) -> None:
         """
         Recursively add node ids
@@ -278,17 +255,11 @@ class Graph(BaseModel):
             node_ids.append(graph_edge.target_node_id)
             cls._recursively_add_node_ids(
-                node_ids=node_ids,
-                edge_mapping=edge_mapping,
-                node_id=graph_edge.target_node_id
+                node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
             )

     @classmethod
-    def _check_connected_to_previous_node(
-        cls,
-        route: list[str],
-        edge_mapping: dict[str, list[GraphEdge]]
-    ) -> None:
+    def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
         """
         Check whether it is connected to the previous node
         """
@@ -299,7 +270,9 @@ class Graph(BaseModel):
                 continue

             if graph_edge.target_node_id in route:
-                raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
+                raise ValueError(
+                    f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
+                )

             new_route = route[:]
             new_route.append(graph_edge.target_node_id)
@@ -316,7 +289,7 @@ class Graph(BaseModel):
         start_node_id: str,
         parallel_mapping: dict[str, GraphParallel],
         node_parallel_mapping: dict[str, str],
-        parent_parallel: Optional[GraphParallel] = None
+        parent_parallel: Optional[GraphParallel] = None,
     ) -> None:
         """
         Recursively add parallel ids
@@ -355,14 +328,14 @@ class Graph(BaseModel):
             parallel = GraphParallel(
                 start_from_node_id=start_node_id,
                 parent_parallel_id=parent_parallel.id if parent_parallel else None,
-                parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
+                parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
             )
             parallel_mapping[parallel.id] = parallel

             in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                 edge_mapping=edge_mapping,
                 reverse_edge_mapping=reverse_edge_mapping,
-                parallel_branch_node_ids=parallel_branch_node_ids
+                parallel_branch_node_ids=parallel_branch_node_ids,
             )

             # collect all branches node ids
@@ -403,14 +376,25 @@ class Graph(BaseModel):
                     continue

                 if (
-                    (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
-                    or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
+                    (
+                        node_parallel_mapping.get(target_node_id)
+                        and node_parallel_mapping.get(target_node_id) == parent_parallel_id
+                    )
+                    or (
+                        parent_parallel
+                        and parent_parallel.end_to_node_id
+                        and target_node_id == parent_parallel.end_to_node_id
+                    )
                     or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                 ):
                     outside_parallel_target_node_ids.add(target_node_id)

             if len(outside_parallel_target_node_ids) == 1:
-                if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
+                if (
+                    parent_parallel
+                    and parent_parallel.end_to_node_id
+                    and parallel.end_to_node_id == parent_parallel.end_to_node_id
+                ):
                     parallel.end_to_node_id = None
                 else:
                     parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
@@ -420,18 +404,20 @@ class Graph(BaseModel):
             if parallel:
                 current_parallel = parallel
             elif parent_parallel:
-                if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
+                if not parent_parallel.end_to_node_id or (
+                    parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
+                ):
                     current_parallel = parent_parallel
                 else:
                     # fetch parent parallel's parent parallel
                     parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                     if parent_parallel_parent_parallel_id:
                         parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
-                        if (
-                            parent_parallel_parent_parallel
-                            and (
-                                not parent_parallel_parent_parallel.end_to_node_id
-                                or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
-                            )
+                        if parent_parallel_parent_parallel and (
+                            not parent_parallel_parent_parallel.end_to_node_id
+                            or (
+                                parent_parallel_parent_parallel.end_to_node_id
+                                and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
+                            )
                         ):
                             current_parallel = parent_parallel_parent_parallel
@@ -442,7 +428,7 @@ class Graph(BaseModel):
                 start_node_id=graph_edge.target_node_id,
                 parallel_mapping=parallel_mapping,
                 node_parallel_mapping=node_parallel_mapping,
-                parent_parallel=current_parallel
+                parent_parallel=current_parallel,
             )

     @classmethod
@@ -451,7 +437,7 @@ class Graph(BaseModel):
         parallel_mapping: dict[str, GraphParallel],
         level_limit: int,
         parent_parallel_id: str,
-        current_level: int = 1
+        current_level: int = 1,
     ) -> None:
         """
         Check if it exceeds N layers of parallel
@@ -459,25 +445,27 @@ class Graph(BaseModel):
         parent_parallel = parallel_mapping.get(parent_parallel_id)
         if not parent_parallel:
             return

         current_level += 1
         if current_level > level_limit:
             raise ValueError(f"Exceeds {level_limit} layers of parallel")

         if parent_parallel.parent_parallel_id:
             cls._check_exceed_parallel_limit(
                 parallel_mapping=parallel_mapping,
                 level_limit=level_limit,
                 parent_parallel_id=parent_parallel.parent_parallel_id,
-                current_level=current_level
+                current_level=current_level,
             )

     @classmethod
-    def _recursively_add_parallel_node_ids(cls,
-                                           branch_node_ids: list[str],
-                                           edge_mapping: dict[str, list[GraphEdge]],
-                                           merge_node_id: str,
-                                           start_node_id: str) -> None:
+    def _recursively_add_parallel_node_ids(
+        cls,
+        branch_node_ids: list[str],
+        edge_mapping: dict[str, list[GraphEdge]],
+        merge_node_id: str,
+        start_node_id: str,
+    ) -> None:
         """
         Recursively add node ids
@@ -487,21 +475,22 @@ class Graph(BaseModel):
         :param start_node_id: start node id
         """
         for graph_edge in edge_mapping.get(start_node_id, []):
-            if (graph_edge.target_node_id != merge_node_id
-                    and graph_edge.target_node_id not in branch_node_ids):
+            if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
                 branch_node_ids.append(graph_edge.target_node_id)
                 cls._recursively_add_parallel_node_ids(
                     branch_node_ids=branch_node_ids,
                     edge_mapping=edge_mapping,
                     merge_node_id=merge_node_id,
-                    start_node_id=graph_edge.target_node_id
+                    start_node_id=graph_edge.target_node_id,
                 )

     @classmethod
-    def _fetch_all_node_ids_in_parallels(cls,
-                                         edge_mapping: dict[str, list[GraphEdge]],
-                                         reverse_edge_mapping: dict[str, list[GraphEdge]],
-                                         parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
+    def _fetch_all_node_ids_in_parallels(
+        cls,
+        edge_mapping: dict[str, list[GraphEdge]],
+        reverse_edge_mapping: dict[str, list[GraphEdge]],
+        parallel_branch_node_ids: list[str],
+    ) -> dict[str, list[str]]:
         """
         Fetch all node ids in parallels
         """
@@ -513,7 +502,7 @@ class Graph(BaseModel):
             cls._recursively_fetch_routes(
                 edge_mapping=edge_mapping,
                 start_node_id=parallel_branch_node_id,
-                routes_node_ids=routes_node_ids[parallel_branch_node_id]
+                routes_node_ids=routes_node_ids[parallel_branch_node_id],
             )

         # fetch leaf node ids from routes node ids
@@ -529,13 +518,13 @@ class Graph(BaseModel):
             for branch_node_id2, inner_route2 in routes_node_ids.items():
                 if (
-                        branch_node_id != branch_node_id2
+                    branch_node_id != branch_node_id2
                     and node_id in inner_route2
                     and len(reverse_edge_mapping.get(node_id, [])) > 1
                     and cls._is_node_in_routes(
                         reverse_edge_mapping=reverse_edge_mapping,
                         start_node_id=node_id,
-                        routes_node_ids=routes_node_ids
+                        routes_node_ids=routes_node_ids,
                     )
                 ):
                     if node_id not in merge_branch_node_ids:
@@ -551,23 +540,18 @@ class Graph(BaseModel):
         for node_id, branch_node_ids in merge_branch_node_ids.items():
             for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
                 if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
-                    if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
+                    if (node_id, node_id2) not in duplicate_end_node_ids and (
+                        node_id2,
+                        node_id,
+                    ) not in duplicate_end_node_ids:
                         duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

         for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
             # check which node is after
-            if cls._is_node2_after_node1(
-                node1_id=node_id,
-                node2_id=node_id2,
-                edge_mapping=edge_mapping
-            ):
+            if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
                 if node_id in merge_branch_node_ids:
                     del merge_branch_node_ids[node_id2]
-            elif cls._is_node2_after_node1(
-                node1_id=node_id2,
-                node2_id=node_id,
-                edge_mapping=edge_mapping
-            ):
+            elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
                 if node_id2 in merge_branch_node_ids:
                     del merge_branch_node_ids[node_id]
@@ -599,16 +583,15 @@ class Graph(BaseModel):
                 branch_node_ids=in_branch_node_ids[branch_node_id],
                 edge_mapping=edge_mapping,
                 merge_node_id=merge_node_id,
-                start_node_id=branch_node_id
+                start_node_id=branch_node_id,
             )

         return in_branch_node_ids

     @classmethod
-    def _recursively_fetch_routes(cls,
-                                  edge_mapping: dict[str, list[GraphEdge]],
-                                  start_node_id: str,
-                                  routes_node_ids: list[str]) -> None:
+    def _recursively_fetch_routes(
+        cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
+    ) -> None:
         """
         Recursively fetch route
         """
@@ -621,28 +604,25 @@ class Graph(BaseModel):
                 routes_node_ids.append(graph_edge.target_node_id)

             cls._recursively_fetch_routes(
-                edge_mapping=edge_mapping,
-                start_node_id=graph_edge.target_node_id,
-                routes_node_ids=routes_node_ids
+                edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
             )

     @classmethod
-    def _is_node_in_routes(cls,
-                           reverse_edge_mapping: dict[str, list[GraphEdge]],
-                           start_node_id: str,
-                           routes_node_ids: dict[str, list[str]]) -> bool:
+    def _is_node_in_routes(
+        cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
+    ) -> bool:
         """
         Recursively check if the node is in the routes
         """
         if start_node_id not in reverse_edge_mapping:
             return False

         all_routes_node_ids = set()
         parallel_start_node_ids: dict[str, list[str]] = {}
         for branch_node_id, node_ids in routes_node_ids.items():
             for node_id in node_ids:
                 all_routes_node_ids.add(node_id)

             if branch_node_id in reverse_edge_mapping:
                 for graph_edge in reverse_edge_mapping[branch_node_id]:
                     if graph_edge.source_node_id not in parallel_start_node_ids:
@@ -655,38 +635,34 @@ class Graph(BaseModel):
                     if set(branch_node_ids) == set(routes_node_ids.keys()):
                         parallel_start_node_id = p_start_node_id
                         return True

         if not parallel_start_node_id:
             raise Exception("Parallel start node id not found")

         for graph_edge in reverse_edge_mapping[start_node_id]:
-            if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
+            if (
+                graph_edge.source_node_id not in all_routes_node_ids
+                or graph_edge.source_node_id != parallel_start_node_id
+            ):
                 return False

         return True

     @classmethod
-    def _is_node2_after_node1(
-        cls,
-        node1_id: str,
-        node2_id: str,
-        edge_mapping: dict[str, list[GraphEdge]]
-    ) -> bool:
+    def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
         """
         is node2 after node1
         """
         if node1_id not in edge_mapping:
             return False

         for graph_edge in edge_mapping[node1_id]:
             if graph_edge.target_node_id == node2_id:
                 return True

             if cls._is_node2_after_node1(
-                node1_id=graph_edge.target_node_id,
-                node2_id=node2_id,
-                edge_mapping=edge_mapping
+                node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
             ):
                 return True

             return False

         return False
@@ -10,7 +10,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRoute
 class GraphRuntimeState(BaseModel):
     variable_pool: VariablePool = Field(..., description="variable pool")
     """variable pool"""

     start_at: float = Field(..., description="start time")
     """start time"""
     total_tokens: int = 0
@@ -18,4 +18,4 @@ class RunCondition(BaseModel):
     @property
     def hash(self) -> str:
-        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
+        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
@@ -68,13 +68,11 @@ class RouteNodeState(BaseModel):
 class RuntimeRouteState(BaseModel):
     routes: dict[str, list[str]] = Field(
-        default_factory=dict,
-        description="graph state routes (source_node_state_id: target_node_state_id)"
+        default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
     )

     node_state_mapping: dict[str, RouteNodeState] = Field(
-        default_factory=dict,
-        description="node state mapping (route_node_state_id: route_node_state)"
+        default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
     )

     def create_node_state(self, node_id: str) -> RouteNodeState:
@@ -99,13 +97,13 @@ class RuntimeRouteState(BaseModel):
         self.routes[source_node_state_id].append(target_node_state_id)

-    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
-            -> list[RouteNodeState]:
+    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
         """
         Get routes with node state by source node id

         :param source_node_state_id: source node state id
         :return: routes with node state
         """
-        return [self.node_state_mapping[target_state_id]
-                for target_state_id in self.routes.get(source_node_state_id, [])]
+        return [
+            self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
+        ]
@@ -48,8 +48,9 @@ logger = logging.getLogger(__name__)
 class GraphEngineThreadPool(ThreadPoolExecutor):
-    def __init__(self, max_workers=None, thread_name_prefix='',
-                 initializer=None, initargs=(), max_submit_count=100) -> None:
+    def __init__(
+        self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100
+    ) -> None:
         super().__init__(max_workers, thread_name_prefix, initializer, initargs)
         self.max_submit_count = max_submit_count
         self.submit_count = 0
@@ -57,9 +58,9 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
     def submit(self, fn, *args, **kwargs):
         self.submit_count += 1
         self.check_is_full()

         return super().submit(fn, *args, **kwargs)

     def check_is_full(self) -> None:
         print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
         if self.submit_count > self.max_submit_count:
@@ -70,21 +71,21 @@ class GraphEngine:
     workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}

     def __init__(
-            self,
-            tenant_id: str,
-            app_id: str,
-            workflow_type: WorkflowType,
-            workflow_id: str,
-            user_id: str,
-            user_from: UserFrom,
-            invoke_from: InvokeFrom,
-            call_depth: int,
-            graph: Graph,
-            graph_config: Mapping[str, Any],
-            variable_pool: VariablePool,
-            max_execution_steps: int,
-            max_execution_time: int,
-            thread_pool_id: Optional[str] = None
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_type: WorkflowType,
+        workflow_id: str,
+        user_id: str,
+        user_from: UserFrom,
+        invoke_from: InvokeFrom,
+        call_depth: int,
+        graph: Graph,
+        graph_config: Mapping[str, Any],
+        variable_pool: VariablePool,
+        max_execution_steps: int,
+        max_execution_time: int,
+        thread_pool_id: Optional[str] = None,
     ) -> None:
         thread_pool_max_submit_count = 100
         thread_pool_max_workers = 10
@@ -93,12 +94,14 @@ class GraphEngine:
         if thread_pool_id:
             if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
                 raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")

             self.thread_pool_id = thread_pool_id
             self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
             self.is_main_thread_pool = False
         else:
-            self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
+            self.thread_pool = GraphEngineThreadPool(
+                max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count
+            )
             self.thread_pool_id = str(uuid.uuid4())
             self.is_main_thread_pool = True
             GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
@@ -113,13 +116,10 @@ class GraphEngine:
             user_id=user_id,
             user_from=user_from,
             invoke_from=invoke_from,
-            call_depth=call_depth
+            call_depth=call_depth,
         )

-        self.graph_runtime_state = GraphRuntimeState(
-            variable_pool=variable_pool,
-            start_at=time.perf_counter()
-        )
+        self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

         self.max_execution_steps = max_execution_steps
         self.max_execution_time = max_execution_time
@@ -136,37 +136,40 @@ class GraphEngine:
             stream_processor_cls = EndStreamProcessor

         stream_processor = stream_processor_cls(
-            graph=self.graph,
-            variable_pool=self.graph_runtime_state.variable_pool
+            graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
         )

         # run graph
-        generator = stream_processor.process(
-            self._run(start_node_id=self.graph.root_node_id)
-        )
+        generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))

         for item in generator:
             try:
                 yield item
                 if isinstance(item, NodeRunFailedEvent):
-                    yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
+                    yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
                     return
                 elif isinstance(item, NodeRunSucceededEvent):
                     if item.node_type == NodeType.END:
-                        self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
-                                                            if item.route_node_state.node_run_result
-                                                            and item.route_node_state.node_run_result.outputs
-                                                            else {})
+                        self.graph_runtime_state.outputs = (
+                            item.route_node_state.node_run_result.outputs
+                            if item.route_node_state.node_run_result
+                            and item.route_node_state.node_run_result.outputs
+                            else {}
+                        )
                     elif item.node_type == NodeType.ANSWER:
                         if "answer" not in self.graph_runtime_state.outputs:
                             self.graph_runtime_state.outputs["answer"] = ""

-                        self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
-                                                                              if item.route_node_state.node_run_result
-                                                                              and item.route_node_state.node_run_result.outputs
-                                                                              else "")
-
-                        self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
+                        self.graph_runtime_state.outputs["answer"] += "\n" + (
+                            item.route_node_state.node_run_result.outputs.get("answer", "")
+                            if item.route_node_state.node_run_result
+                            and item.route_node_state.node_run_result.outputs
+                            else ""
+                        )
+
+                        self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[
+                            "answer"
+                        ].strip()
             except Exception as e:
                 logger.exception(f"Graph run failed: {str(e)}")
                 yield GraphRunFailedEvent(error=str(e))
@@ -186,12 +189,12 @@ class GraphEngine:
                 del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]

     def _run(
-            self,
-            start_node_id: str,
-            in_parallel_id: Optional[str] = None,
-            parent_parallel_id: Optional[str] = None,
-            parent_parallel_start_node_id: Optional[str] = None
-    ) -> Generator[GraphEngineEvent, None, None]:
+        self,
+        start_node_id: str,
+        in_parallel_id: Optional[str] = None,
+        parent_parallel_id: Optional[str] = None,
+        parent_parallel_start_node_id: Optional[str] = None,
+    ) -> Generator[GraphEngineEvent, None, None]:
         parallel_start_node_id = None
         if in_parallel_id:
             parallel_start_node_id = start_node_id
@@ -201,31 +204,28 @@ class GraphEngine:
         while True:
             # max steps reached
             if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
-                raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
+                raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))

             # or max execution time reached
             if self._is_timed_out(
-                start_at=self.graph_runtime_state.start_at,
-                max_execution_time=self.max_execution_time
+                start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
             ):
-                raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
+                raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))

             # init route node state
-            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
-                node_id=next_node_id
-            )
+            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)

             # get node config
             node_id = route_node_state.node_id
             node_config = self.graph.node_id_config_mapping.get(node_id)
             if not node_config:
-                raise GraphRunFailedError(f'Node {node_id} config not found.')
+                raise GraphRunFailedError(f"Node {node_id} config not found.")

             # convert to specific node
-            node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+            node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
             node_cls = node_classes.get(node_type)
             if not node_cls:
-                raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
+                raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.")

             previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
@@ -237,7 +237,7 @@ class GraphEngine:
                 graph=self.graph,
                 graph_runtime_state=self.graph_runtime_state,
                 previous_node_id=previous_node_id,
-                thread_pool_id=self.thread_pool_id
+                thread_pool_id=self.thread_pool_id,
             )

             try:
@@ -248,7 +248,7 @@ class GraphEngine:
                     parallel_id=in_parallel_id,
                     parallel_start_node_id=parallel_start_node_id,
                     parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                 )

                 for item in generator:
@@ -263,8 +263,7 @@ class GraphEngine:
                 # append route
                 if previous_route_node_state:
                     self.graph_runtime_state.node_run_state.add_route(
-                        source_node_state_id=previous_route_node_state.id,
-                        target_node_state_id=route_node_state.id
+                        source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id
                     )
             except Exception as e:
                 route_node_state.status = RouteNodeState.Status.FAILED
@@ -279,13 +278,15 @@ class GraphEngine:
                     parallel_id=in_parallel_id,
                     parallel_start_node_id=parallel_start_node_id,
                     parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                 )
                 raise e

             # It may not be necessary, but it is necessary. :)
-            if (self.graph.node_id_config_mapping[next_node_id]
-                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
+            if (
+                self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
+                == NodeType.END.value
+            ):
                 break

             previous_route_node_state = route_node_state
@@ -305,7 +306,7 @@ class GraphEngine:
                         run_condition=edge.run_condition,
                     ).check(
                         graph_runtime_state=self.graph_runtime_state,
-                        previous_route_node_state=previous_route_node_state
+                        previous_route_node_state=previous_route_node_state,
                     )

                     if not result:
@@ -343,14 +344,14 @@ class GraphEngine:
                         if not result:
                             continue

                     if len(sub_edge_mappings) == 1:
                         final_node_id = edge.target_node_id
                     else:
                         parallel_generator = self._run_parallel_branches(
                             edge_mappings=sub_edge_mappings,
                             in_parallel_id=in_parallel_id,
-                            parallel_start_node_id=parallel_start_node_id
+                            parallel_start_node_id=parallel_start_node_id,
                         )

                         for item in parallel_generator:
@@ -369,7 +370,7 @@ class GraphEngine:
                 parallel_generator = self._run_parallel_branches(
                     edge_mappings=edge_mappings,
                     in_parallel_id=in_parallel_id,
-                    parallel_start_node_id=parallel_start_node_id
+                    parallel_start_node_id=parallel_start_node_id,
                 )

                 for item in parallel_generator:
@@ -383,14 +384,14 @@ class GraphEngine:
             next_node_id = final_node_id

-            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
+            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id:
                 break

     def _run_parallel_branches(
-            self,
-            edge_mappings: list[GraphEdge],
-            in_parallel_id: Optional[str] = None,
-            parallel_start_node_id: Optional[str] = None,
+        self,
+        edge_mappings: list[GraphEdge],
+        in_parallel_id: Optional[str] = None,
+        parallel_start_node_id: Optional[str] = None,
     ) -> Generator[GraphEngineEvent | str, None, None]:
         # if nodes has no run conditions, parallel run all nodes
         parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@@ -398,14 +399,18 @@ class GraphEngine:
             node_id = edge_mappings[0].target_node_id
             node_config = self.graph.node_id_config_mapping.get(node_id)
             if not node_config:
-                raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
+                raise GraphRunFailedError(
+                    f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches."
+                )

-            node_title = node_config.get('data', {}).get('title')
-            raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
+            node_title = node_config.get("data", {}).get("title")
+            raise GraphRunFailedError(
+                f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches."
+            )

         parallel = self.graph.parallel_mapping.get(parallel_id)
         if not parallel:
-            raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
+            raise GraphRunFailedError(f"Parallel {parallel_id} not found.")

         # run parallel nodes, run in new thread and use queue to get results
         q: queue.Queue = queue.Queue()
@@ -417,19 +422,22 @@ class GraphEngine:
         for edge in edge_mappings:
             if (
                 edge.target_node_id not in self.graph.node_parallel_mapping
-                or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
+                or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
             ):
                 continue

             futures.append(
-                self.thread_pool.submit(self._run_parallel_node, **{
-                    'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
-                    'q': q,
-                    'parallel_id': parallel_id,
-                    'parallel_start_node_id': edge.target_node_id,
-                    'parent_parallel_id': in_parallel_id,
-                    'parent_parallel_start_node_id': parallel_start_node_id,
-                })
+                self.thread_pool.submit(
+                    self._run_parallel_node,
+                    **{
+                        "flask_app": current_app._get_current_object(),  # type: ignore[attr-defined]
+                        "q": q,
+                        "parallel_id": parallel_id,
+                        "parallel_start_node_id": edge.target_node_id,
+                        "parent_parallel_id": in_parallel_id,
+                        "parent_parallel_start_node_id": parallel_start_node_id,
+                    },
+                )
             )

         succeeded_count = 0
@@ -451,7 +459,7 @@ class GraphEngine:
                     raise GraphRunFailedError(event.error)
             except queue.Empty:
                 continue

         # wait all threads
         wait(futures)
@@ -461,72 +469,80 @@ class GraphEngine:
             yield final_node_id

     def _run_parallel_node(
-            self,
-            flask_app: Flask,
-            q: queue.Queue,
-            parallel_id: str,
-            parallel_start_node_id: str,
-            parent_parallel_id: Optional[str] = None,
-            parent_parallel_start_node_id: Optional[str] = None,
+        self,
+        flask_app: Flask,
+        q: queue.Queue,
+        parallel_id: str,
+        parallel_start_node_id: str,
+        parent_parallel_id: Optional[str] = None,
+        parent_parallel_start_node_id: Optional[str] = None,
     ) -> None:
         """
         Run parallel nodes
         """
         with flask_app.app_context():
             try:
-                q.put(ParallelBranchRunStartedEvent(
-                    parallel_id=parallel_id,
-                    parallel_start_node_id=parallel_start_node_id,
-                    parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id
-                ))
+                q.put(
+                    ParallelBranchRunStartedEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    )
+                )

                 # run node
                 generator = self._run(
                     start_node_id=parallel_start_node_id,
                     in_parallel_id=parallel_id,
                     parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                 )

                 for item in generator:
                     q.put(item)

                 # trigger graph run success event
-                q.put(ParallelBranchRunSucceededEvent(
-                    parallel_id=parallel_id,
-                    parallel_start_node_id=parallel_start_node_id,
-                    parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id
-                ))
+                q.put(
+                    ParallelBranchRunSucceededEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    )
+                )
             except GraphRunFailedError as e:
-                q.put(ParallelBranchRunFailedEvent(
-                    parallel_id=parallel_id,
-                    parallel_start_node_id=parallel_start_node_id,
-                    parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                    error=e.error
-                ))
+                q.put(
+                    ParallelBranchRunFailedEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                        error=e.error,
+                    )
+                )
             except Exception as e:
                 logger.exception("Unknown Error when generating in parallel")
-                q.put(ParallelBranchRunFailedEvent(
-                    parallel_id=parallel_id,
-                    parallel_start_node_id=parallel_start_node_id,
-                    parent_parallel_id=parent_parallel_id,
-                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                    error=str(e)
-                ))
+                q.put(
+                    ParallelBranchRunFailedEvent(
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                        error=str(e),
+                    )
+                )
             finally:
                 db.session.remove()

     def _run_node(
-            self,
-            node_instance: BaseNode,
-            route_node_state: RouteNodeState,
-            parallel_id: Optional[str] = None,
-            parallel_start_node_id: Optional[str] = None,
-            parent_parallel_id: Optional[str] = None,
-            parent_parallel_start_node_id: Optional[str] = None,
+        self,
+        node_instance: BaseNode,
+        route_node_state: RouteNodeState,
+        parallel_id: Optional[str] = None,
+        parallel_start_node_id: Optional[str] = None,
+        parent_parallel_id: Optional[str] = None,
+        parent_parallel_start_node_id: Optional[str] = None,
     ) -> Generator[GraphEngineEvent, None, None]:
         """
         Run node
@@ -542,7 +558,7 @@ class GraphEngine:
                 parallel_id=parallel_id,
                 parallel_start_node_id=parallel_start_node_id,
                 parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id
+                parent_parallel_start_node_id=parent_parallel_start_node_id,
             )

             db.session.close()
@@ -567,7 +583,7 @@ class GraphEngine:
                     if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                         yield NodeRunFailedEvent(
-                            error=route_node_state.failed_reason or 'Unknown error.',
+                            error=route_node_state.failed_reason or "Unknown error.",
                             id=node_instance.id,
                             node_id=node_instance.node_id,
                             node_type=node_instance.node_type,
@@ -576,7 +592,7 @@ class GraphEngine:
                             parallel_id=parallel_id,
                             parallel_start_node_id=parallel_start_node_id,
                             parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
                         )
                     elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                         if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
@@ -596,7 +612,7 @@ class GraphEngine:
                                 self._append_variables_recursively(
                                     node_id=node_instance.node_id,
                                     variable_key_list=[variable_key],
-                                    variable_value=variable_value
+                                    variable_value=variable_value,
                                 )

                         # add parallel info to run result metadata
@@ -608,7 +624,9 @@ class GraphEngine:
                             run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
                             if parent_parallel_id and parent_parallel_start_node_id:
                                 run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
-                                run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id
+                                run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+                                    parent_parallel_start_node_id
+                                )

                         yield NodeRunSucceededEvent(
                             id=node_instance.id,
@@ -619,7 +637,7 @@ class GraphEngine:
                             parallel_id=parallel_id,
                             parallel_start_node_id=parallel_start_node_id,
                             parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
                         )

                         break
@@ -635,7 +653,7 @@ class GraphEngine:
                             parallel_id=parallel_id,
                             parallel_start_node_id=parallel_start_node_id,
                             parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
                         )
                     elif isinstance(item, RunRetrieverResourceEvent):
                         yield NodeRunRetrieverResourceEvent(
@@ -649,7 +667,7 @@ class GraphEngine:
                             parallel_id=parallel_id,
                             parallel_start_node_id=parallel_start_node_id,
                             parent_parallel_id=parent_parallel_id,
-                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
                         )
         except GenerateTaskStoppedException:
             # trigger node run failed event
@@ -665,7 +683,7 @@ class GraphEngine:
                 parallel_id=parallel_id,
                 parallel_start_node_id=parallel_start_node_id,
                 parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id
+                parent_parallel_start_node_id=parent_parallel_start_node_id,
             )
             return
         except Exception as e:
@@ -674,10 +692,7 @@ class GraphEngine:
         finally:
             db.session.close()

-    def _append_variables_recursively(self,
-                                      node_id: str,
-                                      variable_key_list: list[str],
-                                      variable_value: VariableValue):
+    def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """
         Append variables recursively
         :param node_id: node id
@@ -685,10 +700,7 @@ class GraphEngine:
         :param variable_value: variable value
         :return:
         """
-        self.graph_runtime_state.variable_pool.add(
-            [node_id] + variable_key_list,
-            variable_value
-        )
+        self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value)

         # if variable_value is a dict, then recursively append variables
         if isinstance(variable_value, dict):
@@ -696,9 +708,7 @@ class GraphEngine:
                 # construct new key list
                 new_key_list = variable_key_list + [key]
                 self._append_variables_recursively(
-                    node_id=node_id,
-                    variable_key_list=new_key_list,
-                    variable_value=value
+                    node_id=node_id, variable_key_list=new_key_list, variable_value=value
                 )

     def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: