chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = None
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
@@ -29,6 +33,7 @@ class IterationState(BaseIterationState):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
@@ -38,9 +43,9 @@ class IterationState(BaseIterationState):
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
return self.current_output
|
||||
|
@@ -33,6 +33,7 @@ class IterationNode(BaseNode):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
@@ -45,31 +46,26 @@ class IterationNode(BaseNode):
|
||||
|
||||
if not iterator_list_segment:
|
||||
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
|
||||
iterator_list_value = iterator_list_segment.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
inputs = {
|
||||
"iterator_selector": iterator_list_value
|
||||
}
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
graph_config = self.graph_config
|
||||
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
|
||||
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
|
||||
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id
|
||||
)
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
raise ValueError("iteration graph not found")
|
||||
|
||||
leaf_node_ids = iteration_graph.get_leaf_node_ids()
|
||||
iteration_leaf_node_ids = []
|
||||
@@ -97,26 +93,21 @@ class IterationNode(BaseNode):
|
||||
Condition(
|
||||
variable_selector=[self.node_id, "index"],
|
||||
comparison_operator="<",
|
||||
value=str(len(iterator_list_value))
|
||||
value=str(len(iterator_list_value)),
|
||||
)
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
0
|
||||
)
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[0]
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
@@ -130,7 +121,7 @@ class IterationNode(BaseNode):
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
@@ -142,10 +133,8 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={
|
||||
"iterator_length": len(iterator_list_value)
|
||||
},
|
||||
predecessor_node_id=self.previous_node_id
|
||||
metadata={"iterator_length": len(iterator_list_value)},
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
@@ -154,7 +143,7 @@ class IterationNode(BaseNode):
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_iteration_output=None
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
|
||||
outputs: list[Any] = []
|
||||
@@ -176,7 +165,9 @@ class IterationNode(BaseNode):
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
|
||||
[self.node_id, "index"]
|
||||
)
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
@@ -192,21 +183,15 @@ class IterationNode(BaseNode):
|
||||
variable_pool.remove_node(node_id)
|
||||
|
||||
# move to next iteration
|
||||
current_index = variable_pool.get([self.node_id, 'index'])
|
||||
current_index = variable_pool.get([self.node_id, "index"])
|
||||
if current_index is None:
|
||||
raise ValueError(f'iteration {self.node_id} current index not found')
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
|
||||
next_index = int(current_index.to_object()) + 1
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
next_index
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[next_index]
|
||||
)
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
@@ -214,8 +199,9 @@ class IterationNode(BaseNode):
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_iteration_output=jsonable_encoder(
|
||||
current_iteration_output) if current_iteration_output else None
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output)
|
||||
if current_iteration_output
|
||||
else None,
|
||||
)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
@@ -227,13 +213,9 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@@ -255,21 +237,14 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
}
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'output': jsonable_encoder(outputs)
|
||||
}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -282,16 +257,11 @@ class IterationNode(BaseNode):
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@@ -301,15 +271,12 @@ class IterationNode(BaseNode):
|
||||
)
|
||||
finally:
|
||||
# remove iteration variable (item, index) from variable pool after iteration run completed
|
||||
variable_pool.remove([self.node_id, 'index'])
|
||||
variable_pool.remove([self.node_id, 'item'])
|
||||
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
variable_pool.remove([self.node_id, "item"])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -319,36 +286,33 @@ class IterationNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {
|
||||
f'{node_id}.input_selector': node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_data.start_node_id
|
||||
)
|
||||
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
|
||||
raise ValueError("iteration graph not found")
|
||||
|
||||
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
|
||||
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
|
||||
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
|
||||
|
||||
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
config=sub_node_config
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@@ -356,7 +320,8 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove iteration variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
|
||||
sub_node_id + "." + key: value
|
||||
for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
@@ -364,8 +329,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items()
|
||||
if value[0] not in iteration_graph.node_ids
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
|
||||
|
||||
return variable_mapping
|
||||
|
@@ -11,6 +11,7 @@ class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
@@ -18,16 +19,11 @@ class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
|
Reference in New Issue
Block a user