chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -25,18 +25,11 @@ class EndNode(BaseNode):
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
outputs[variable_selector.variable] = value
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str]
) -> EndStreamParam:
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str],
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
@@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.END.value:
if not node_config.get("data", {}).get("type") == NodeType.END.value:
continue
# skip end node in parallel
@@ -33,18 +33,18 @@ class EndStreamGeneratorRouter:
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
node_id_config_mapping=node_id_config_mapping,
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)
@classmethod
def extract_stream_variable_selector_from_node_data(cls,
node_id_config_mapping: dict[str, dict],
node_data: EndNodeData) -> list[list[str]]:
def extract_stream_variable_selector_from_node_data(
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
@@ -59,21 +59,22 @@ class EndStreamGeneratorRouter:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_id_config_mapping:
if node_id != "sys" and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get('data', {}).get('type')
node_type = node.get("data", {}).get("type")
if (
variable_selector.value_selector not in value_selectors
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == 'text'
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(variable_selector.value_selector)
return value_selectors
@classmethod
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
-> list[list[str]]:
def _extract_stream_variable_selector(
cls, node_id_config_mapping: dict[str, dict], config: dict
) -> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
@@ -84,11 +85,12 @@ class EndStreamGeneratorRouter:
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
def _fetch_ends_dependencies(
cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
@@ -106,20 +108,21 @@ class EndStreamGeneratorRouter:
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]]
) -> None:
def _recursive_fetch_end_dependencies(
cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
@@ -132,10 +135,10 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in (
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
end_dependencies[end_node_id].append(source_node_id)
else:
@@ -144,5 +147,5 @@ class EndStreamGeneratorRouter:
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
end_dependencies=end_dependencies,
)

View File

@@ -15,7 +15,6 @@ logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
@@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor):
self.has_outputed = False
self.outputed_node_ids = set()
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
@@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor):
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
event.chunk_content = "\n" + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
@@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor):
]
else:
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_end_node_ids
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_end_node_ids
)
if stream_out_end_node_ids:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
event.chunk_content = "\n" + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
@@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
@@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor):
"""
for end_node_id, position in self.route_position.items():
# all depends on end node id not in rest node ids
if (event.route_node_state.node_id != end_node_id
and (end_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
if event.route_node_state.node_id != end_node_id and (
end_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
)
):
continue
route_position = self.route_position[end_node_id]
@@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor):
if not value_selector:
continue
value = self.variable_pool.get(
value_selector
)
value = self.variable_pool.get(value_selector)
if value is None:
break
@@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor):
if text:
current_node_id = value_selector[0]
if self.has_outputed and current_node_id not in self.outputed_node_ids:
text = '\n' + text
text = "\n" + text
self.outputed_node_ids.add(current_node_id)
self.has_outputed = True
@@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor):
continue
# all depends on end node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
continue
@@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor):
break
position += 1
if not value_selector:
continue

View File

@@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData):
"""
END Node Data.
"""
outputs: list[VariableSelector]
@@ -15,11 +16,10 @@ class EndStreamParam(BaseModel):
"""
EndStreamParam entity
"""
end_dependencies: dict[str, list[str]] = Field(
...,
description="end dependencies (end node id -> dependent node ids)"
..., description="end dependencies (end node id -> dependent node ids)"
)
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
...,
description="end stream variable selector mapping (end node id -> stream variable selectors)"
..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
)