feat/enhance the multi-modal support (#8818)

This commit is contained in:
-LAN-
2024-10-21 10:43:49 +08:00
committed by GitHub
parent 7a1d6fe509
commit e61752bd3a
267 changed files with 6263 additions and 3523 deletions

View File

@@ -0,0 +1,5 @@
from .entities import IterationNodeData
from .iteration_node import IterationNode
from .iteration_start_node import IterationStartNode
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]

View File

@@ -1,6 +1,8 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData):
@@ -26,7 +28,7 @@ class IterationState(BaseIterationState):
Iteration State.
"""
outputs: list[Any] = None
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
class MetaData(BaseIterationState.MetaData):

View File

@@ -5,7 +5,7 @@ from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@@ -20,15 +20,16 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseNode):
class IterationNode(BaseNode[IterationNodeData]):
"""
Iteration Node.
"""
@@ -36,11 +37,10 @@ class IterationNode(BaseNode):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
@@ -177,7 +177,7 @@ class IterationNode(BaseNode):
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
variable_pool.remove([node_id])
# move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
@@ -247,7 +247,11 @@ class IterationNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -273,15 +277,13 @@ class IterationNode(BaseNode):
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes.node_mapping import node_type_classes_mapping
node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
)

View File

@@ -1,8 +1,9 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus