chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -29,14 +29,12 @@ class AnswerNode(BaseNode):
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
answer = ''
answer = ""
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(
value_selector
)
value = self.graph_runtime_state.variable_pool.get(value_selector)
if value:
answer += value.markdown
@@ -44,19 +42,11 @@ class AnswerNode(BaseNode):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"answer": answer
}
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -73,6 +63,6 @@ class AnswerNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping

View File

@@ -1,4 +1,3 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
@@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
@@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter:
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value:
continue
# get generate route for stream output
@@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter:
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route,
answer_dependencies=answer_dependencies
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
)
@classmethod
@@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter:
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
@@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter:
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
for part in template.split("Ω"):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
generate_routes.append(TextGenerateRouteChunk(text=part))
return generate_routes
@@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter:
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
@@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter:
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
def _recursive_fetch_answer_dependencies(
cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
@@ -152,11 +147,11 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in (
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
@@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter:
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
answer_dependencies=answer_dependencies,
)

View File

@@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
@@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor):
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
@@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor):
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_answer_node_ids
)
for _ in stream_out_answer_node_ids:
yield event
@@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
@@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor):
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
):
continue
route_position = self.route_position[answer_node_id]
@@ -115,9 +115,7 @@ class AnswerStreamProcessor(StreamProcessor):
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
value = self.variable_pool.get(value_selector)
if value is None:
break
@@ -158,8 +156,9 @@ class AnswerStreamProcessor(StreamProcessor):
continue
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if all(
dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
@@ -213,7 +212,7 @@ class AnswerStreamProcessor(StreamProcessor):
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()

View File

@@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
@@ -35,9 +32,11 @@ class StreamProcessor(ABC):
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:

View File

@@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData):
"""
Answer Node Data.
"""
answer: str = Field(..., description="answer template string")
@@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
"""generate route chunk type"""
value_selector: list[str] = Field(..., description="value selector")
@@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
"""generate route chunk type"""
text: str = Field(..., description="text")
@@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel):
"""
AnswerStreamGenerateRoute entity
"""
answer_dependencies: dict[str, list[str]] = Field(
...,
description="answer dependencies (answer node id -> dependent answer node ids)"
..., description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
...,
description="answer generate route (answer node id -> generate route chunks)"
..., description="answer generate route (answer node id -> generate route chunks)"
)