feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from .entities import IterationNodeData
|
||||
from .iteration_node import IterationNode
|
||||
from .iteration_start_node import IterationStartNode
|
||||
|
||||
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]
|
||||
|
@@ -1,6 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
@@ -26,7 +28,7 @@ class IterationState(BaseIterationState):
|
||||
Iteration State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = None
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
|
@@ -5,7 +5,7 @@ from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
@@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IterationNode(BaseNode):
|
||||
class IterationNode(BaseNode[IterationNodeData]):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
@@ -36,11 +37,10 @@ class IterationNode(BaseNode):
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
self.node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
if not iterator_list_segment:
|
||||
@@ -177,7 +177,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove_node(node_id)
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
current_index = variable_pool.get([self.node_id, "index"])
|
||||
@@ -247,7 +247,11 @@ class IterationNode(BaseNode):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -273,15 +277,13 @@ class IterationNode(BaseNode):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
|
||||
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
|
@@ -1,8 +1,9 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
Reference in New Issue
Block a user