Feat: upgrade variable assigner (#11285)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Yi Xiao
2024-12-03 13:56:40 +08:00
committed by GitHub
parent e79eac688a
commit e135ffc2c1
62 changed files with 1565 additions and 301 deletions

View File

@@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@@ -227,7 +227,8 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping[node_type]
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

View File

@@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
NodeType.VARIABLE_ASSIGNER,
}:
answer_dependencies[answer_node_id].append(source_node_id)
else:

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"
class BaseIterationNodeData(BaseNodeData):

View File

@@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = cast(GenericNodeData, node_data)
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

View File

@@ -14,11 +14,11 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
VARIABLE_AGGREGATOR = "variable-aggregator"
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"

View File

@@ -298,12 +298,13 @@ class IterationNode(BaseNode[IterationNodeData]):
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
if not node_cls:
if node_type not in NODE_TYPE_CLASSES_MAPPING:
continue
node_version = sub_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config

View File

@@ -1,3 +1,5 @@
from collections.abc import Mapping
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
@@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_OPERATOR: ListOperatorNode,
LATEST_VERSION = "latest"
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
}

View File

@@ -1,8 +0,0 @@
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
__all__ = [
"VariableAssignerData",
"VariableAssignerNode",
"WriteMode",
]

View File

@@ -0,0 +1,4 @@
class VariableOperatorNodeError(Exception):
"""Base error type, don't use directly."""
pass

View File

@@ -0,0 +1,19 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()

View File

@@ -1,2 +0,0 @@
class VariableAssignerNodeError(Exception):
pass

View File

@@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@@ -1,40 +1,36 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import SegmentType, Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
@@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
@@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableAssignerNodeError("conversation_id not found")
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
@@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Optional
from core.workflow.nodes.base import BaseNodeData
@@ -12,8 +11,6 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
title: str = "Variable Assigner"
desc: Optional[str] = "Assign a value to a variable"
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@@ -0,0 +1,3 @@
from .node import VariableAssignerNode
__all__ = ["VariableAssignerNode"]

View File

@@ -0,0 +1,11 @@
from core.variables import SegmentType
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
}

View File

@@ -0,0 +1,20 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from core.workflow.nodes.base import BaseNodeData
from .enums import InputType, Operation
class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
value: Any | None = None
class VariableAssignerNodeData(BaseNodeData):
version: str = "2"
items: Sequence[VariableOperationItem]

View File

@@ -0,0 +1,18 @@
from enum import StrEnum
class Operation(StrEnum):
OVER_WRITE = "over-write"
CLEAR = "clear"
APPEND = "append"
EXTEND = "extend"
SET = "set"
ADD = "+="
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
class InputType(StrEnum):
VARIABLE = "variable"
CONSTANT = "constant"

View File

@@ -0,0 +1,31 @@
from collections.abc import Sequence
from typing import Any
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from .enums import InputType, Operation
class OperationNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, operation: Operation, varialbe_type: str):
super().__init__(f"Operation {operation} is not supported for type {varialbe_type}")
class InputTypeNotSupportedError(VariableOperatorNodeError):
def __init__(self, *, input_type: InputType, operation: Operation):
super().__init__(f"Input type {input_type} is not supported for operation {operation}")
class VariableNotFoundError(VariableOperatorNodeError):
def __init__(self, *, variable_selector: Sequence[str]):
super().__init__(f"Variable {variable_selector} not found")
class InvalidInputValueError(VariableOperatorNodeError):
def __init__(self, *, value: Any):
super().__init__(f"Invalid input value {value}")
class ConversationIDNotFoundError(VariableOperatorNodeError):
def __init__(self):
super().__init__("conversation_id not found")

View File

@@ -0,0 +1,91 @@
from typing import Any
from core.variables import SegmentType
from .enums import Operation
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
match operation:
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type == SegmentType.NUMBER
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case _:
return False
def is_variable_input_supported(*, operation: Operation):
if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
return False
return True
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
Operation.ADD,
Operation.SUBTRACT,
Operation.MULTIPLY,
Operation.DIVIDE,
}
case _:
return False
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
if operation == Operation.CLEAR:
return True
match variable_type:
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.NUMBER:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:
return False
return True
case SegmentType.OBJECT:
return isinstance(value, dict)
# Array & Append
case SegmentType.ARRAY_ANY if operation == Operation.APPEND:
return isinstance(value, str | float | int | dict)
case SegmentType.ARRAY_STRING if operation == Operation.APPEND:
return isinstance(value, str)
case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND:
return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict)
# Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value)
case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, str) for item in value)
case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case _:
return False

View File

@@ -0,0 +1,159 @@
import json
from typing import Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
try:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
raise OperationNotSupportedError(operation=item.operation, varialbe_type=variable.value_type)
# Check if variable input is supported
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
# Check if constant input is supported
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
variable_type=variable.value_type, operation=item.operation
):
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
if value is None:
raise VariableNotFoundError(variable_selector=item.value)
# Skip if value is NoneSegment
if value.value_type == SegmentType.NONE:
continue
item.value = value.value
# If set string / bytes / bytearray to object, try convert string to object.
if (
item.operation == Operation.SET
and variable.value_type == SegmentType.OBJECT
and isinstance(item.value, str | bytes | bytearray)
):
try:
item.value = json.loads(item.value)
except json.JSONDecodeError:
raise InvalidInputValueError(value=item.value)
# Check if input value is valid
if not helpers.is_input_value_valid(
variable_type=variable.value_type, operation=item.operation, value=item.value
):
raise InvalidInputValueError(value=item.value)
# ==================== Execution Part
updated_value = self._handle_item(
variable=variable,
operation=item.operation,
value=item.value,
)
variable = variable.model_copy(update={"value": updated_value})
updated_variables.append(variable)
except VariableOperatorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
error=str(e),
)
# Update variables
for variable in updated_variables:
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conversation_id=conversation_id,
variable=variable,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
)
def _handle_item(
self,
*,
variable: Variable,
operation: Operation,
value: Any,
):
match operation:
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:
return variable.value + value
case Operation.SET:
return value
case Operation.ADD:
return variable.value + value
case Operation.SUBTRACT:
return variable.value - value
case Operation.MULTIPLY:
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case _:
raise OperationNotSupportedError(operation=operation, varialbe_type=variable.value_type)

View File

@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@@ -145,11 +145,8 @@ class WorkflowEntry:
# Get node class
node_type = NodeType(node_config.get("data", {}).get("type"))
node_cls = node_type_classes_mapping.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
node_version = node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(environment_variables=workflow.environment_variables)