feat/enhance the multi-modal support (#8818)
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from .answer_node import AnswerNode
|
||||
from .entities import AnswerStreamGenerateRoute
|
||||
|
||||
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
|
||||
|
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
@@ -9,12 +10,13 @@ from core.workflow.nodes.answer.entities import (
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode):
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
_node_data_cls = AnswerNodeData
|
||||
_node_type: NodeType = NodeType.ANSWER
|
||||
|
||||
@@ -23,30 +25,35 @@ class AnswerNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
# generate routes
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
|
||||
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
|
||||
|
||||
answer = ""
|
||||
files = []
|
||||
for part in generate_routes:
|
||||
if part.type == GenerateRouteChunk.ChunkType.VAR:
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
|
||||
if value:
|
||||
answer += value.markdown
|
||||
variable = self.graph_runtime_state.variable_pool.get(value_selector)
|
||||
if variable:
|
||||
if isinstance(variable, FileSegment):
|
||||
files.append(variable.value)
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
files.extend(variable.value)
|
||||
answer += variable.markdown
|
||||
else:
|
||||
part = cast(TextGenerateRouteChunk, part)
|
||||
answer += part.text
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -55,9 +62,6 @@ class AnswerNode(BaseNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
node_data = cast(AnswerNodeData, node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
|
@@ -1,5 +1,4 @@
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
AnswerStreamGenerateRoute,
|
||||
@@ -7,6 +6,7 @@ from core.workflow.nodes.answer.entities import (
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
@@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
|
||||
def _get_file_var_from_value(cls, value: dict | list):
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
@@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
@@ -2,7 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
|
Reference in New Issue
Block a user