chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -29,14 +29,12 @@ class AnswerNode(BaseNode):
|
||||
# generate routes
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
answer = ''
|
||||
answer = ""
|
||||
for part in generate_routes:
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
|
||||
if value:
|
||||
answer += value.markdown
|
||||
@@ -44,19 +42,11 @@ class AnswerNode(BaseNode):
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"answer": answer
|
||||
}
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AnswerNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -73,6 +63,6 @@ class AnswerNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
@@ -1,4 +1,3 @@
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
@@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
def init(
|
||||
cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
@@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter:
|
||||
# parse stream output node value selectors of answer nodes
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
|
||||
for answer_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
|
||||
if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
@@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_dependencies = cls._fetch_answers_dependencies(
|
||||
answer_node_ids=answer_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
)
|
||||
|
||||
return AnswerStreamGenerateRoute(
|
||||
answer_generate_route=answer_generate_route,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter:
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in variable_selectors
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
@@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter:
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
|
||||
|
||||
generate_routes: list[GenerateRouteChunk] = []
|
||||
for part in template.split('Ω'):
|
||||
for part in template.split("Ω"):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
generate_routes.append(TextGenerateRouteChunk(text=part))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
cleaned_part = part.replace("{{", "").replace("}}", "")
|
||||
return part.startswith("{{") and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _fetch_answers_dependencies(cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
def _fetch_answers_dependencies(
|
||||
cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch answer dependencies
|
||||
:param answer_node_ids: answer node ids
|
||||
@@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
|
||||
return answer_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_answer_dependencies(cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
def _recursive_fetch_answer_dependencies(
|
||||
cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
@@ -152,11 +147,11 @@ class AnswerStreamGeneratorRouter:
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
|
||||
if source_node_type in (
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
@@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter:
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
answer_dependencies=answer_dependencies,
|
||||
)
|
||||
|
@@ -18,7 +18,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
@@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
@@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
]
|
||||
else:
|
||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_answer_node_ids
|
||||
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
|
||||
stream_out_answer_node_ids
|
||||
)
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
yield event
|
||||
@@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def _generate_stream_outputs_when_node_finished(
|
||||
self, event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
@@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
"""
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
# all depends on answer node id not in rest node ids
|
||||
if (event.route_node_state.node_id != answer_node_id
|
||||
and (answer_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||
if event.route_node_state.node_id != answer_node_id and (
|
||||
answer_node_id not in self.rest_node_ids
|
||||
or not all(
|
||||
dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
@@ -115,9 +115,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
value = self.variable_pool.get(value_selector)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
@@ -158,8 +156,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
continue
|
||||
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
if all(
|
||||
dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
|
||||
):
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
@@ -213,7 +212,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__name__:
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
return value.to_dict()
|
||||
|
@@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
|
||||
@abstractmethod
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
@@ -35,9 +32,11 @@ class StreamProcessor(ABC):
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
||||
if (
|
||||
edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify
|
||||
):
|
||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
continue
|
||||
else:
|
||||
|
@@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
@@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||
"""generate route chunk type"""
|
||||
value_selector: list[str] = Field(..., description="value selector")
|
||||
@@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||
"""generate route chunk type"""
|
||||
text: str = Field(..., description="text")
|
||||
@@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
AnswerStreamGenerateRoute entity
|
||||
"""
|
||||
|
||||
answer_dependencies: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
..., description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
)
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
|
||||
...,
|
||||
description="answer generate route (answer node id -> generate route chunks)"
|
||||
..., description="answer generate route (answer node id -> generate route chunks)"
|
||||
)
|
||||
|
Reference in New Issue
Block a user