chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
"""
Iteration Start Node Data.
"""
pass
class IterationState(BaseIterationState):
"""
Iteration State.
"""
outputs: list[Any] = None
current_output: Optional[Any] = None
@@ -29,6 +33,7 @@ class IterationState(BaseIterationState):
"""
Data.
"""
iterator_length: int
def get_last_output(self) -> Optional[Any]:
@@ -38,9 +43,9 @@ class IterationState(BaseIterationState):
if self.outputs:
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
"""
Get current output.
"""
return self.current_output
return self.current_output

View File

@@ -33,6 +33,7 @@ class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
@@ -45,31 +46,26 @@ class IterationNode(BaseNode):
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
iterator_list_value = iterator_list_segment.to_object()
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {
"iterator_selector": iterator_list_value
}
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=root_node_id
)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise ValueError('iteration graph not found')
raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
@@ -97,26 +93,21 @@ class IterationNode(BaseNode):
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value))
value=str(len(iterator_list_value)),
)
]
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add(
[self.node_id, 'index'],
0
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[0]
)
variable_pool.add([self.node_id, "index"], 0)
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@@ -130,7 +121,7 @@ class IterationNode(BaseNode):
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -142,10 +133,8 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
metadata={
"iterator_length": len(iterator_list_value)
},
predecessor_node_id=self.previous_node_id
metadata={"iterator_length": len(iterator_list_value)},
predecessor_node_id=self.previous_node_id,
)
yield IterationRunNextEvent(
@@ -154,7 +143,7 @@ class IterationNode(BaseNode):
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None
pre_iteration_output=None,
)
outputs: list[Any] = []
@@ -176,7 +165,9 @@ class IterationNode(BaseNode):
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
[self.node_id, "index"]
)
event.route_node_state.node_run_result.metadata = metadata
yield event
@@ -192,21 +183,15 @@ class IterationNode(BaseNode):
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, 'index'])
current_index = variable_pool.get([self.node_id, "index"])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
@@ -214,8 +199,9 @@ class IterationNode(BaseNode):
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
@@ -227,13 +213,9 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
@@ -255,21 +237,14 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(outputs)
}
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
)
)
except Exception as e:
@@ -282,16 +257,11 @@ class IterationNode(BaseNode):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -301,15 +271,12 @@ class IterationNode(BaseNode):
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
variable_pool.remove([self.node_id, "index"])
variable_pool.remove([self.node_id, "item"])
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -319,36 +286,33 @@ class IterationNode(BaseNode):
:return:
"""
variable_mapping = {
f'{node_id}.input_selector': node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=node_data.start_node_id
)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
if not iteration_graph:
raise ValueError('iteration graph not found')
raise ValueError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
continue
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
config=sub_node_config
graph_config=graph_config, config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
except NotImplementedError:
@@ -356,7 +320,8 @@ class IterationNode(BaseNode):
# remove iteration variables
sub_node_variable_mapping = {
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
sub_node_id + "." + key: value
for key, value in sub_node_variable_mapping.items()
if value[0] != node_id
}
@@ -364,8 +329,7 @@ class IterationNode(BaseNode):
# remove variable out from iteration
variable_mapping = {
key: value for key, value in variable_mapping.items()
if value[0] not in iteration_graph.node_ids
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
return variable_mapping

View File

@@ -11,6 +11,7 @@ class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
@@ -18,16 +19,11 @@ class IterationStartNode(BaseNode):
"""
Run the node.
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping