chore(api/core): apply ruff reformatting (#7624)

Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
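The change is mechanical: ruff's formatter collapses multi-line signatures and call sites that fit within the configured line length (apparently around 120 characters in this project), normalizes string literals to double quotes, and adds trailing commas to argument lists that remain multi-line. A self-contained before/after sketch of the pattern (illustrative only, not taken from any single file below):

# Before: hanging-indent parameters, single-quoted string literals
def get_handle(edge_config,
               default='source'):
    return edge_config.get('sourceHandle', default)

# After ruff format: the signature fits within the line length, so it is
# collapsed onto one line, and quotes are normalized to double quotes
def get_handle(edge_config, default="source"):
    return edge_config.get("sourceHandle", default)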

View File

@@ -8,19 +8,13 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed

View File

@@ -4,10 +4,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState) -> bool:
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed

View File

@@ -5,10 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
@@ -22,8 +19,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions
variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions
)
# Apply the logical operator for the current case

View File

@@ -9,9 +9,7 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
) -> RunConditionHandler:
"""
Get condition handler
@@ -22,14 +20,6 @@ class ConditionManager:
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
else:
return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)

View File

@@ -34,38 +34,25 @@ class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=list,
description="node configs mapping (node id: node config)"
default_factory=list, description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="graph edge mapping (source node id: edges)"
default_factory=dict, description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="reverse graph edge mapping (target node id: edges)"
default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
...,
description="answer stream generate routes"
)
end_stream_param: EndStreamParam = Field(
...,
description="end stream param"
default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
end_stream_param: EndStreamParam = Field(..., description="end stream param")
@classmethod
def init(cls,
graph_config: Mapping[str, Any],
root_node_id: Optional[str] = None) -> "Graph":
def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
@@ -74,7 +61,7 @@ class Graph(BaseModel):
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
@@ -85,14 +72,14 @@ class Graph(BaseModel):
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
source_node_id = edge_config.get("source")
if not source_node_id:
continue
if source_node_id not in edge_mapping:
edge_mapping[source_node_id] = []
target_node_id = edge_config.get('target')
target_node_id = edge_config.get("target")
if not target_node_id:
continue
@@ -107,23 +94,18 @@ class Graph(BaseModel):
# parse run condition
run_condition = None
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get('nodes')
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
@@ -133,7 +115,7 @@ class Graph(BaseModel):
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
node_id = node_config.get("id")
if not node_id:
continue
@@ -142,30 +124,29 @@ class Graph(BaseModel):
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
root_node_ids = [node_config.get("id") for node_config in root_node_configs]
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next((node_config.get("id") for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
root_node_id = next(
(
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
),
None,
)
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Check whether it is connected to the previous node
cls._check_connected_to_previous_node(
route=[root_node_id],
edge_mapping=edge_mapping
)
cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)
# fetch all node ids from root node
node_ids = [root_node_id]
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=root_node_id
)
cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
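For orientation only (not part of the diff): the keys read above imply a graph_config shape roughly like the sketch below. Real node "data" payloads carry many more fields needed by later init steps, and the type strings are assumptions based on the NodeType comparisons in this file; only the fields this parser touches are shown.

# Illustrative config covering the keys Graph.init() reads:
# "nodes"/"edges", node "id" and data["type"], edge "source"/"target"/"sourceHandle".
graph_config = {
    "nodes": [
        {"id": "start", "data": {"type": "start", "title": "Start"}},      # assumed NodeType.START.value
        {"id": "branch", "data": {"type": "if-else", "title": "Branch"}},  # type string assumed
        {"id": "end", "data": {"type": "end", "title": "End"}},            # assumed NodeType.END.value
    ],
    "edges": [
        {"source": "start", "target": "branch", "sourceHandle": "source"},
        # a sourceHandle other than "source" becomes a branch_identify RunCondition
        {"source": "branch", "target": "end", "sourceHandle": "true"},
    ],
}
# Graph.init(graph_config=graph_config) would parse this into the edge/node mappings
# built above; whether the later stream-router init steps accept such minimal node
# data is not shown in this diff.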
@@ -177,29 +158,26 @@ class Graph(BaseModel):
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
node_parallel_mapping=node_parallel_mapping,
)
# Check if it exceeds N layers of parallel
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=3,
parent_parallel_id=parallel.parent_parallel_id
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping
node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping
node_parallel_mapping=node_parallel_mapping,
)
# init graph
@@ -212,14 +190,14 @@ class Graph(BaseModel):
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param
end_stream_param=end_stream_param,
)
return graph
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
run_condition: Optional[RunCondition] = None) -> None:
def add_extra_edge(
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
) -> None:
"""
Add extra edge to the graph
@@ -237,9 +215,7 @@ class Graph(BaseModel):
return
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
@@ -254,17 +230,18 @@ class Graph(BaseModel):
for node_id in self.node_ids:
if node_id not in self.edge_mapping:
leaf_node_ids.append(node_id)
elif (len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
elif (
len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(cls,
node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
node_id: str) -> None:
def _recursively_add_node_ids(
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
) -> None:
"""
Recursively add node ids
@@ -278,17 +255,11 @@ class Graph(BaseModel):
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(
cls,
route: list[str],
edge_mapping: dict[str, list[GraphEdge]]
) -> None:
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
"""
Check whether it is connected to the previous node
"""
@@ -299,7 +270,9 @@ class Graph(BaseModel):
continue
if graph_edge.target_node_id in route:
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
raise ValueError(
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
)
new_route = route[:]
new_route.append(graph_edge.target_node_id)
@@ -316,7 +289,7 @@ class Graph(BaseModel):
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None
parent_parallel: Optional[GraphParallel] = None,
) -> None:
"""
Recursively add parallel ids
@@ -355,14 +328,14 @@ class Graph(BaseModel):
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids
parallel_branch_node_ids=parallel_branch_node_ids,
)
# collect all branches node ids
@@ -403,14 +376,25 @@ class Graph(BaseModel):
continue
if (
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
@@ -420,18 +404,20 @@ class Graph(BaseModel):
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if (
parent_parallel_parent_parallel
and (
not parent_parallel_parent_parallel.end_to_node_id
or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
):
current_parallel = parent_parallel_parent_parallel
@@ -442,7 +428,7 @@ class Graph(BaseModel):
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel
parent_parallel=current_parallel,
)
@classmethod
@@ -451,7 +437,7 @@ class Graph(BaseModel):
parallel_mapping: dict[str, GraphParallel],
level_limit: int,
parent_parallel_id: str,
current_level: int = 1
current_level: int = 1,
) -> None:
"""
Check if it exceeds N layers of parallel
@@ -459,25 +445,27 @@ class Graph(BaseModel):
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
return
current_level += 1
if current_level > level_limit:
raise ValueError(f"Exceeds {level_limit} layers of parallel")
if parent_parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=level_limit,
parent_parallel_id=parent_parallel.parent_parallel_id,
current_level=current_level
current_level=current_level,
)
@classmethod
def _recursively_add_parallel_node_ids(cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str) -> None:
def _recursively_add_parallel_node_ids(
cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str,
) -> None:
"""
Recursively add node ids
@@ -487,21 +475,22 @@ class Graph(BaseModel):
:param start_node_id: start node id
"""
for graph_edge in edge_mapping.get(start_node_id, []):
if (graph_edge.target_node_id != merge_node_id
and graph_edge.target_node_id not in branch_node_ids):
if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
branch_node_ids.append(graph_edge.target_node_id)
cls._recursively_add_parallel_node_ids(
branch_node_ids=branch_node_ids,
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=graph_edge.target_node_id
start_node_id=graph_edge.target_node_id,
)
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
def _fetch_all_node_ids_in_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str],
) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
"""
@@ -513,7 +502,7 @@ class Graph(BaseModel):
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_branch_node_id]
routes_node_ids=routes_node_ids[parallel_branch_node_id],
)
# fetch leaf node ids from routes node ids
@@ -529,13 +518,13 @@ class Graph(BaseModel):
for branch_node_id2, inner_route2 in routes_node_ids.items():
if (
branch_node_id != branch_node_id2
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids
routes_node_ids=routes_node_ids,
)
):
if node_id not in merge_branch_node_ids:
@@ -551,23 +540,18 @@ class Graph(BaseModel):
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
if (node_id, node_id2) not in duplicate_end_node_ids and (
node_id2,
node_id,
) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(
node1_id=node_id,
node2_id=node_id2,
edge_mapping=edge_mapping
):
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
@@ -599,16 +583,15 @@ class Graph(BaseModel):
branch_node_ids=in_branch_node_ids[branch_node_id],
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=branch_node_id
start_node_id=branch_node_id,
)
return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(cls,
edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: list[str]) -> None:
def _recursively_fetch_routes(
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
) -> None:
"""
Recursively fetch route
"""
@@ -621,28 +604,25 @@ class Graph(BaseModel):
routes_node_ids.append(graph_edge.target_node_id)
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=graph_edge.target_node_id,
routes_node_ids=routes_node_ids
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(cls,
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: dict[str, list[str]]) -> bool:
def _is_node_in_routes(
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.add(node_id)
if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
@@ -655,38 +635,34 @@ class Graph(BaseModel):
if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True
if not parallel_start_node_id:
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
if (
graph_edge.source_node_id not in all_routes_node_ids
or graph_edge.source_node_id != parallel_start_node_id
):
return False
return True
@classmethod
def _is_node2_after_node1(
cls,
node1_id: str,
node2_id: str,
edge_mapping: dict[str, list[GraphEdge]]
) -> bool:
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
"""
is node2 after node1
"""
if node1_id not in edge_mapping:
return False
for graph_edge in edge_mapping[node1_id]:
if graph_edge.target_node_id == node2_id:
return True
if cls._is_node2_after_node1(
node1_id=graph_edge.target_node_id,
node2_id=node2_id,
edge_mapping=edge_mapping
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
):
return True
return False
return False

View File

@@ -10,7 +10,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0

View File

@@ -18,4 +18,4 @@ class RunCondition(BaseModel):
@property
def hash(self) -> str:
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
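Not part of the diff, but worth noting about the property above: because hash is a SHA-256 of the serialized model, two conditions with equal field values produce the same digest. A minimal sketch (the import path comes from a hunk header earlier in this diff; the field values are illustrative):

from core.workflow.graph_engine.entities.run_condition import RunCondition

a = RunCondition(type="branch_identify", branch_identify="true")
b = RunCondition(type="branch_identify", branch_identify="true")
assert a.hash == b.hash   # same fields -> same SHA-256 hex digest
assert len(a.hash) == 64  # sha256 hexdigest is 64 hex characters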

View File

@@ -68,13 +68,11 @@ class RouteNodeState(BaseModel):
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
default_factory=dict,
description="graph state routes (source_node_state_id: target_node_state_id)"
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
)
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict,
description="node state mapping (route_node_state_id: route_node_state)"
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
)
def create_node_state(self, node_id: str) -> RouteNodeState:
@@ -99,13 +97,13 @@ class RuntimeRouteState(BaseModel):
self.routes[source_node_state_id].append(target_node_state_id)
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
-> list[RouteNodeState]:
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
"""
Get routes with node state by source node id
:param source_node_state_id: source node state id
:return: routes with node state
"""
return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])]
return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
]

View File

@@ -48,8 +48,9 @@ logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
def __init__(self, max_workers=None, thread_name_prefix='',
initializer=None, initargs=(), max_submit_count=100) -> None:
def __init__(
self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100
) -> None:
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
self.max_submit_count = max_submit_count
self.submit_count = 0
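Not part of the diff: a minimal usage sketch of the pool above. It behaves like a regular ThreadPoolExecutor, except that submit() also counts cumulative submissions and check_is_full() (next hunk; its raise is truncated in this view) is expected to reject work once submit_count exceeds max_submit_count.

# Sketch only: the cap is on total submissions over the pool's lifetime,
# not on the number of concurrently running workers.
pool = GraphEngineThreadPool(max_workers=4, max_submit_count=100)
future = pool.submit(sum, [1, 2, 3])  # bumps submit_count, then calls check_is_full()
print(future.result())  # 6
pool.shutdown(wait=True)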
@@ -57,9 +58,9 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
def submit(self, fn, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
return super().submit(fn, *args, **kwargs)
def check_is_full(self) -> None:
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
if self.submit_count > self.max_submit_count:
@@ -70,21 +71,21 @@ class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None
self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_workers = 10
@@ -93,12 +94,14 @@ class GraphEngine:
if thread_pool_id:
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
self.thread_pool_id = thread_pool_id
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
self.is_main_thread_pool = False
else:
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
self.thread_pool = GraphEngineThreadPool(
max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count
)
self.thread_pool_id = str(uuid.uuid4())
self.is_main_thread_pool = True
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
@@ -113,13 +116,10 @@ class GraphEngine:
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth
call_depth=call_depth,
)
self.graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter()
)
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
@@ -136,37 +136,40 @@ class GraphEngine:
stream_processor_cls = EndStreamProcessor
stream_processor = stream_processor_cls(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
)
# run graph
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id)
)
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {})
self.graph_runtime_state.outputs = (
item.route_node_state.node_run_result.outputs
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {}
)
elif item.node_type == NodeType.ANSWER:
if "answer" not in self.graph_runtime_state.outputs:
self.graph_runtime_state.outputs["answer"] = ""
self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else "")
self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
self.graph_runtime_state.outputs["answer"] += "\n" + (
item.route_node_state.node_run_result.outputs.get("answer", "")
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else ""
)
self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[
"answer"
].strip()
except Exception as e:
logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(error=str(e))
@@ -186,12 +189,12 @@ class GraphEngine:
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
def _run(
self,
start_node_id: str,
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None
) -> Generator[GraphEngineEvent, None, None]:
self,
start_node_id: str,
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
@@ -201,31 +204,28 @@ class GraphEngine:
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
# or max execution time reached
if self._is_timed_out(
start_at=self.graph_runtime_state.start_at,
max_execution_time=self.max_execution_time
start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
):
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
node_id=next_node_id
)
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
raise GraphRunFailedError(f"Node {node_id} config not found.")
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.")
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
@@ -237,7 +237,7 @@ class GraphEngine:
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id
thread_pool_id=self.thread_pool_id,
)
try:
@@ -248,7 +248,7 @@ class GraphEngine:
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
for item in generator:
@@ -263,8 +263,7 @@ class GraphEngine:
# append route
if previous_route_node_state:
self.graph_runtime_state.node_run_state.add_route(
source_node_state_id=previous_route_node_state.id,
target_node_state_id=route_node_state.id
source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id
)
except Exception as e:
route_node_state.status = RouteNodeState.Status.FAILED
@@ -279,13 +278,15 @@ class GraphEngine:
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
raise e
# It may not be necessary, but it is necessary. :)
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
== NodeType.END.value
):
break
previous_route_node_state = route_node_state
@@ -305,7 +306,7 @@ class GraphEngine:
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state
previous_route_node_state=previous_route_node_state,
)
if not result:
@@ -343,14 +344,14 @@ class GraphEngine:
if not result:
continue
if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
parallel_start_node_id=parallel_start_node_id,
)
for item in parallel_generator:
@@ -369,7 +370,7 @@ class GraphEngine:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
parallel_start_node_id=parallel_start_node_id,
)
for item in parallel_generator:
@@ -383,14 +384,14 @@ class GraphEngine:
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id:
break
def _run_parallel_branches(
self,
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
self,
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@@ -398,14 +399,18 @@ class GraphEngine:
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
raise GraphRunFailedError(
f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches."
)
node_title = node_config.get('data', {}).get('title')
raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
node_title = node_config.get("data", {}).get("title")
raise GraphRunFailedError(
f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches."
)
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
raise GraphRunFailedError(f"Parallel {parallel_id} not found.")
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
@@ -417,19 +422,22 @@ class GraphEngine:
for edge in edge_mappings:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
):
continue
futures.append(
self.thread_pool.submit(self._run_parallel_node, **{
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
self.thread_pool.submit(
self._run_parallel_node,
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
},
)
)
succeeded_count = 0
@@ -451,7 +459,7 @@ class GraphEngine:
raise GraphRunFailedError(event.error)
except queue.Empty:
continue
# wait all threads
wait(futures)
@@ -461,72 +469,80 @@ class GraphEngine:
yield final_node_id
def _run_parallel_node(
self,
flask_app: Flask,
q: queue.Queue,
parallel_id: str,
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
self,
flask_app: Flask,
q: queue.Queue,
parallel_id: str,
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
) -> None:
"""
Run parallel nodes
"""
with flask_app.app_context():
try:
q.put(ParallelBranchRunStartedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
))
q.put(
ParallelBranchRunStartedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
)
# run node
generator = self._run(
start_node_id=parallel_start_node_id,
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
for item in generator:
q.put(item)
# trigger graph run success event
q.put(ParallelBranchRunSucceededEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
))
q.put(
ParallelBranchRunSucceededEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
)
except GraphRunFailedError as e:
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=e.error
))
q.put(
ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=e.error,
)
)
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=str(e)
))
q.put(
ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=str(e),
)
)
finally:
db.session.remove()
def _run_node(
self,
node_instance: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
self,
node_instance: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
@@ -542,7 +558,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
db.session.close()
@@ -567,7 +583,7 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or 'Unknown error.',
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
@@ -576,7 +592,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
@@ -596,7 +612,7 @@ class GraphEngine:
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
variable_value=variable_value,
)
# add parallel info to run result metadata
@@ -608,7 +624,9 @@ class GraphEngine:
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
id=node_instance.id,
@@ -619,7 +637,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
@@ -635,7 +653,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
@@ -649,7 +667,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedException:
# trigger node run failed event
@@ -665,7 +683,7 @@ class GraphEngine:
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
@@ -674,10 +692,7 @@ class GraphEngine:
finally:
db.session.close()
def _append_variables_recursively(self,
node_id: str,
variable_key_list: list[str],
variable_value: VariableValue):
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
@@ -685,10 +700,7 @@ class GraphEngine:
:param variable_value: variable value
:return:
"""
self.graph_runtime_state.variable_pool.add(
[node_id] + variable_key_list,
variable_value
)
self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
@@ -696,9 +708,7 @@ class GraphEngine:
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
node_id=node_id,
variable_key_list=new_key_list,
variable_value=value
node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: