Feat: upgrade variable assigner (#11285)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
@@ -227,7 +227,8 @@ class GraphEngine:
|
||||
|
||||
# convert to specific node
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||
|
||||
|
@@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
|
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
version: str = "1"
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
|
@@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
|
||||
|
||||
node_data = self._node_data_cls.model_validate(config.get("data", {}))
|
||||
self.node_data = cast(GenericNodeData, node_data)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
|
@@ -14,11 +14,11 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
@@ -298,12 +298,13 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
if not node_cls:
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
@@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
|
||||
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode,
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
}
|
||||
|
@@ -1,8 +0,0 @@
|
||||
from .node import VariableAssignerNode
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
__all__ = [
|
||||
"VariableAssignerData",
|
||||
"VariableAssignerNode",
|
||||
"WriteMode",
|
||||
]
|
||||
|
4
api/core/workflow/nodes/variable_assigner/common/exc.py
Normal file
4
api/core/workflow/nodes/variable_assigner/common/exc.py
Normal file
@@ -0,0 +1,4 @@
|
||||
class VariableOperatorNodeError(Exception):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
19
api/core/workflow/nodes/variable_assigner/common/helpers.py
Normal file
19
api/core/workflow/nodes/variable_assigner/common/helpers.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
@@ -1,2 +0,0 @@
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
3
api/core/workflow/nodes/variable_assigner/v1/__init__.py
Normal file
3
api/core/workflow/nodes/variable_assigner/v1/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
@@ -1,40 +1,36 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode, BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import VariableAssignerNodeError
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError("assigned variable not found")
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
@@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
@@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError("conversation_id not found")
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
@@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
@@ -12,8 +11,6 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = "Variable Assigner"
|
||||
desc: Optional[str] = "Assign a value to a variable"
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
3
api/core/workflow/nodes/variable_assigner/v2/__init__.py
Normal file
3
api/core/workflow/nodes/variable_assigner/v2/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
11
api/core/workflow/nodes/variable_assigner/v2/constants.py
Normal file
11
api/core/workflow/nodes/variable_assigner/v2/constants.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from core.variables import SegmentType
|
||||
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
}
|
20
api/core/workflow/nodes/variable_assigner/v2/entities.py
Normal file
20
api/core/workflow/nodes/variable_assigner/v2/entities.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class VariableOperationItem(BaseModel):
|
||||
variable_selector: Sequence[str]
|
||||
input_type: InputType
|
||||
operation: Operation
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
version: str = "2"
|
||||
items: Sequence[VariableOperationItem]
|
18
api/core/workflow/nodes/variable_assigner/v2/enums.py
Normal file
18
api/core/workflow/nodes/variable_assigner/v2/enums.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class Operation(StrEnum):
|
||||
OVER_WRITE = "over-write"
|
||||
CLEAR = "clear"
|
||||
APPEND = "append"
|
||||
EXTEND = "extend"
|
||||
SET = "set"
|
||||
ADD = "+="
|
||||
SUBTRACT = "-="
|
||||
MULTIPLY = "*="
|
||||
DIVIDE = "/="
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
VARIABLE = "variable"
|
||||
CONSTANT = "constant"
|
31
api/core/workflow/nodes/variable_assigner/v2/exc.py
Normal file
31
api/core/workflow/nodes/variable_assigner/v2/exc.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class OperationNotSupportedError(VariableOperatorNodeError):
|
||||
def __init__(self, *, operation: Operation, varialbe_type: str):
|
||||
super().__init__(f"Operation {operation} is not supported for type {varialbe_type}")
|
||||
|
||||
|
||||
class InputTypeNotSupportedError(VariableOperatorNodeError):
|
||||
def __init__(self, *, input_type: InputType, operation: Operation):
|
||||
super().__init__(f"Input type {input_type} is not supported for operation {operation}")
|
||||
|
||||
|
||||
class VariableNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self, *, variable_selector: Sequence[str]):
|
||||
super().__init__(f"Variable {variable_selector} not found")
|
||||
|
||||
|
||||
class InvalidInputValueError(VariableOperatorNodeError):
|
||||
def __init__(self, *, value: Any):
|
||||
super().__init__(f"Invalid input value {value}")
|
||||
|
||||
|
||||
class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self):
|
||||
super().__init__("conversation_id not found")
|
91
api/core/workflow/nodes/variable_assigner/v2/helpers.py
Normal file
91
api/core/workflow/nodes/variable_assigner/v2/helpers.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType
|
||||
|
||||
from .enums import Operation
|
||||
|
||||
|
||||
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match operation:
|
||||
case Operation.OVER_WRITE | Operation.CLEAR:
|
||||
return True
|
||||
case Operation.SET:
|
||||
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type == SegmentType.NUMBER
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_input_supported(*, operation: Operation):
|
||||
if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER:
|
||||
return operation in {
|
||||
Operation.OVER_WRITE,
|
||||
Operation.SET,
|
||||
Operation.ADD,
|
||||
Operation.SUBTRACT,
|
||||
Operation.MULTIPLY,
|
||||
Operation.DIVIDE,
|
||||
}
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
|
||||
if operation == Operation.CLEAR:
|
||||
return True
|
||||
match variable_type:
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.NUMBER:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
if operation == Operation.DIVIDE and value == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
case SegmentType.OBJECT:
|
||||
return isinstance(value, dict)
|
||||
|
||||
# Array & Append
|
||||
case SegmentType.ARRAY_ANY if operation == Operation.APPEND:
|
||||
return isinstance(value, str | float | int | dict)
|
||||
case SegmentType.ARRAY_STRING if operation == Operation.APPEND:
|
||||
return isinstance(value, str)
|
||||
case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND:
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value)
|
||||
case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, str) for item in value)
|
||||
case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
159
api/core/workflow/nodes/variable_assigner/v2/node.py
Normal file
159
api/core/workflow/nodes/variable_assigner/v2/node.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidInputValueError,
|
||||
OperationNotSupportedError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variables: list[Variable] = []
|
||||
|
||||
try:
|
||||
for item in self.node_data.items:
|
||||
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
||||
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
|
||||
raise OperationNotSupportedError(operation=item.operation, varialbe_type=variable.value_type)
|
||||
|
||||
# Check if variable input is supported
|
||||
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
|
||||
operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
|
||||
|
||||
# Check if constant input is supported
|
||||
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
|
||||
variable_type=variable.value_type, operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
|
||||
|
||||
# Get value from variable pool
|
||||
if (
|
||||
item.input_type == InputType.VARIABLE
|
||||
and item.operation != Operation.CLEAR
|
||||
and item.value is not None
|
||||
):
|
||||
value = self.graph_runtime_state.variable_pool.get(item.value)
|
||||
if value is None:
|
||||
raise VariableNotFoundError(variable_selector=item.value)
|
||||
# Skip if value is NoneSegment
|
||||
if value.value_type == SegmentType.NONE:
|
||||
continue
|
||||
item.value = value.value
|
||||
|
||||
# If set string / bytes / bytearray to object, try convert string to object.
|
||||
if (
|
||||
item.operation == Operation.SET
|
||||
and variable.value_type == SegmentType.OBJECT
|
||||
and isinstance(item.value, str | bytes | bytearray)
|
||||
):
|
||||
try:
|
||||
item.value = json.loads(item.value)
|
||||
except json.JSONDecodeError:
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# Check if input value is valid
|
||||
if not helpers.is_input_value_valid(
|
||||
variable_type=variable.value_type, operation=item.operation, value=item.value
|
||||
):
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# ==================== Execution Part
|
||||
|
||||
updated_value = self._handle_item(
|
||||
variable=variable,
|
||||
operation=item.operation,
|
||||
value=item.value,
|
||||
)
|
||||
variable = variable.model_copy(update={"value": updated_value})
|
||||
updated_variables.append(variable)
|
||||
except VariableOperatorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Update variables
|
||||
for variable in updated_variables:
|
||||
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=conversation_id,
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
match operation:
|
||||
case Operation.OVER_WRITE:
|
||||
return value
|
||||
case Operation.CLEAR:
|
||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
||||
case Operation.APPEND:
|
||||
return variable.value + [value]
|
||||
case Operation.EXTEND:
|
||||
return variable.value + value
|
||||
case Operation.SET:
|
||||
return value
|
||||
case Operation.ADD:
|
||||
return variable.value + value
|
||||
case Operation.SUBTRACT:
|
||||
return variable.value - value
|
||||
case Operation.MULTIPLY:
|
||||
return variable.value * value
|
||||
case Operation.DIVIDE:
|
||||
return variable.value / value
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, varialbe_type=variable.value_type)
|
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
@@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
@@ -145,11 +145,8 @@ class WorkflowEntry:
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
Reference in New Issue
Block a user