chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -25,18 +25,11 @@ class EndNode(BaseNode):
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
outputs[variable_selector.variable] = value
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=outputs,
|
||||
outputs=outputs
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
@@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str]
|
||||
) -> EndStreamParam:
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str],
|
||||
) -> EndStreamParam:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
@@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
|
||||
# parse stream output node value selector of end nodes
|
||||
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
|
||||
for end_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.END.value:
|
||||
if not node_config.get("data", {}).get("type") == NodeType.END.value:
|
||||
continue
|
||||
|
||||
# skip end node in parallel
|
||||
@@ -33,18 +33,18 @@ class EndStreamGeneratorRouter:
|
||||
end_dependencies = cls._fetch_ends_dependencies(
|
||||
end_node_ids=end_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return EndStreamParam(
|
||||
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_stream_variable_selector_from_node_data(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
node_data: EndNodeData) -> list[list[str]]:
|
||||
def extract_stream_variable_selector_from_node_data(
|
||||
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node data
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
@@ -59,21 +59,22 @@ class EndStreamGeneratorRouter:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != 'sys' and node_id in node_id_config_mapping:
|
||||
if node_id != "sys" and node_id in node_id_config_mapping:
|
||||
node = node_id_config_mapping[node_id]
|
||||
node_type = node.get('data', {}).get('type')
|
||||
node_type = node.get("data", {}).get("type")
|
||||
if (
|
||||
variable_selector.value_selector not in value_selectors
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == 'text'
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == "text"
|
||||
):
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
|
||||
return value_selectors
|
||||
|
||||
@classmethod
|
||||
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
|
||||
-> list[list[str]]:
|
||||
def _extract_stream_variable_selector(
|
||||
cls, node_id_config_mapping: dict[str, dict], config: dict
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node config
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
@@ -84,11 +85,12 @@ class EndStreamGeneratorRouter:
|
||||
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
|
||||
|
||||
@classmethod
|
||||
def _fetch_ends_dependencies(cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
def _fetch_ends_dependencies(
|
||||
cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch end dependencies
|
||||
:param end_node_ids: end node ids
|
||||
@@ -106,20 +108,21 @@ class EndStreamGeneratorRouter:
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
||||
return end_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_end_dependencies(cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
def _recursive_fetch_end_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
@@ -132,10 +135,10 @@ class EndStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in (
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
end_dependencies[end_node_id].append(source_node_id)
|
||||
else:
|
||||
@@ -144,5 +147,5 @@ class EndStreamGeneratorRouter:
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
end_dependencies=end_dependencies,
|
||||
)
|
||||
|
@@ -15,7 +15,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
@@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.has_outputed = False
|
||||
self.outputed_node_ids = set()
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
@@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
@@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor):
|
||||
]
|
||||
else:
|
||||
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_end_node_ids
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_end_node_ids
|
||||
)
|
||||
|
||||
if stream_out_end_node_ids:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
@@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor):
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
@@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor):
|
||||
"""
|
||||
for end_node_id, position in self.route_position.items():
|
||||
# all depends on end node id not in rest node ids
|
||||
if (event.route_node_state.node_id != end_node_id
|
||||
and (end_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
|
||||
if event.route_node_state.node_id != end_node_id and (
|
||||
end_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[end_node_id]
|
||||
@@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
@@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
if text:
|
||||
current_node_id = value_selector[0]
|
||||
if self.has_outputed and current_node_id not in self.outputed_node_ids:
|
||||
text = '\n' + text
|
||||
text = "\n" + text
|
||||
|
||||
self.outputed_node_ids.add(current_node_id)
|
||||
self.has_outputed = True
|
||||
@@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
continue
|
||||
|
||||
# all depends on end node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
|
||||
continue
|
||||
|
||||
@@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
break
|
||||
|
||||
position += 1
|
||||
|
||||
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
|
@@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
outputs: list[VariableSelector]
|
||||
|
||||
|
||||
@@ -15,11 +16,10 @@ class EndStreamParam(BaseModel):
|
||||
"""
|
||||
EndStreamParam entity
|
||||
"""
|
||||
|
||||
end_dependencies: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="end dependencies (end node id -> dependent node ids)"
|
||||
..., description="end dependencies (end node id -> dependent node ids)"
|
||||
)
|
||||
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
|
||||
...,
|
||||
description="end stream variable selector mapping (end node id -> stream variable selectors)"
|
||||
..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
|
||||
)
|
||||
|
Reference in New Issue
Block a user