refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)

This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization.

Key changes:
- Introduce distinct `value_type` tags for Integer and Float segments/variables
- Add `VariableUnion` and `SegmentUnion` types for proper type discrimination
- Leverage Pydantic's discriminated union feature for seamless serialization/deserialization
- Enable accurate serialization of data structures containing these types

Closes #22024.
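For reference, the sketch below is not code from this PR; the class names and the "integer"/"float" tag strings are illustrative. It shows the general mechanism the PR relies on: Pydantic's discriminated unions use a distinct `value_type` literal on each class to pick the correct segment type during deserialization, which a shared `SegmentType.NUMBER` tag previously made impossible.

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class IntegerSegment(BaseModel):
    # Each class carries its own discriminator value.
    value_type: Literal["integer"] = "integer"
    value: int


class FloatSegment(BaseModel):
    value_type: Literal["float"] = "float"
    value: float


# The union dispatches on `value_type` during validation.
SegmentUnion = Annotated[Union[IntegerSegment, FloatSegment], Field(discriminator="value_type")]

adapter = TypeAdapter(SegmentUnion)
assert isinstance(adapter.validate_python({"value_type": "integer", "value": 1}), IntegerSegment)
assert isinstance(adapter.validate_python({"value_type": "float", "value": 1.5}), FloatSegment)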
@@ -1,7 +1,7 @@
 import re
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Any, Union
+from typing import Annotated, Any, Union, cast

 from pydantic import BaseModel, Field

@@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
 from core.variables import Segment, SegmentGroup, Variable
 from core.variables.consts import MIN_SELECTORS_LENGTH
 from core.variables.segments import FileSegment, NoneSegment
+from core.variables.variables import VariableUnion
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
 from factories import variable_factory

 VariableValue = Union[str, int, float, dict, list, File]
@@ -23,31 +24,31 @@ class VariablePool(BaseModel):
     # The first element of the selector is the node id, it's the first-level key in the dictionary.
     # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
     # elements of the selector except the first one.
-    variable_dictionary: dict[str, dict[int, Segment]] = Field(
+    variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
         description="Variables mapping",
         default=defaultdict(dict),
     )

-    # TODO: This user inputs is not used for pool.
+    # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
     user_inputs: Mapping[str, Any] = Field(
         description="User inputs",
         default_factory=dict,
     )
-    system_variables: Mapping[SystemVariableKey, Any] = Field(
+    system_variables: SystemVariable = Field(
         description="System variables",
         default_factory=dict,
     )
-    environment_variables: Sequence[Variable] = Field(
+    environment_variables: Sequence[VariableUnion] = Field(
         description="Environment variables.",
         default_factory=list,
     )
-    conversation_variables: Sequence[Variable] = Field(
+    conversation_variables: Sequence[VariableUnion] = Field(
         description="Conversation variables.",
         default_factory=list,
     )

     def model_post_init(self, context: Any, /) -> None:
-        for key, value in self.system_variables.items():
-            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
+        # Create a mapping from field names to SystemVariableKey enum values
+        self._add_system_variables(self.system_variables)
         # Add environment variables to the variable pool
         for var in self.environment_variables:
             self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@@ -83,8 +84,22 @@ class VariablePool(BaseModel):
         segment = variable_factory.build_segment(value)
         variable = variable_factory.segment_to_variable(segment=segment, selector=selector)

-        hash_key = hash(tuple(selector[1:]))
-        self.variable_dictionary[selector[0]][hash_key] = variable
+        key, hash_key = self._selector_to_keys(selector)
+        # Based on the definition of `VariableUnion`,
+        # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+        self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
+
+    @classmethod
+    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
+        return selector[0], hash(tuple(selector[1:]))
+
+    def _has(self, selector: Sequence[str]) -> bool:
+        key, hash_key = self._selector_to_keys(selector)
+        if key not in self.variable_dictionary:
+            return False
+        if hash_key not in self.variable_dictionary[key]:
+            return False
+        return True

     def get(self, selector: Sequence[str], /) -> Segment | None:
         """
@@ -102,8 +117,8 @@ class VariablePool(BaseModel):
         if len(selector) < MIN_SELECTORS_LENGTH:
             return None

-        hash_key = hash(tuple(selector[1:]))
-        value = self.variable_dictionary[selector[0]].get(hash_key)
+        key, hash_key = self._selector_to_keys(selector)
+        value: Segment | None = self.variable_dictionary[key].get(hash_key)

         if value is None:
             selector, attr = selector[:-1], selector[-1]
@@ -136,8 +151,9 @@ class VariablePool(BaseModel):
         if len(selector) == 1:
             self.variable_dictionary[selector[0]] = {}
             return
-        hash_key = hash(tuple(selector[1:]))
-        self.variable_dictionary[selector[0]].pop(hash_key, None)
+        key, hash_key = self._selector_to_keys(selector)
+        self.variable_dictionary[key].pop(hash_key, None)

     def convert_template(self, template: str, /):
         parts = VARIABLE_PATTERN.split(template)
@@ -154,3 +170,20 @@ class VariablePool(BaseModel):
         if isinstance(segment, FileSegment):
             return segment
         return None
+
+    def _add_system_variables(self, system_variable: SystemVariable):
+        sys_var_mapping = system_variable.to_dict()
+        for key, value in sys_var_mapping.items():
+            if value is None:
+                continue
+            selector = (SYSTEM_VARIABLE_NODE_ID, key)
+            # If the system variable already exists, do not add it again.
+            # This ensures that we can keep the id of the system variables intact.
+            if self._has(selector):
+                continue
+            self.add(selector, value)  # type: ignore
+
+    @classmethod
+    def empty(cls) -> "VariablePool":
+        """Create an empty variable pool."""
+        return cls(system_variables=SystemVariable.empty())
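As a rough usage sketch (not code from this PR; it assumes the `VariablePool` API shown above — `empty()`, `add()`, `get()` — plus Pydantic's standard `model_dump()`/`model_validate()`), the discriminated `VariableUnion` is what lets a pool survive a serialization round trip with integer and float values kept distinct:

from core.workflow.entities.variable_pool import VariablePool

pool = VariablePool.empty()
pool.add(("node_1", "count"), 42)    # built as an integer-typed variable
pool.add(("node_1", "ratio"), 0.5)   # built as a float-typed variable

# With separate value_type tags, model_validate can rebuild the exact
# segment/variable classes from the dumped data.
restored = VariablePool.model_validate(pool.model_dump())
assert restored.get(("node_1", "count")).value == 42
assert restored.get(("node_1", "ratio")).value == 0.5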
@@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
     """total tokens"""
     llm_usage: LLMUsage = LLMUsage.empty_usage()
     """llm usage info"""

+    # The `outputs` field stores the final output values generated by executing workflows or chatflows.
+    #
+    # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
+    # after a serialization and deserialization round trip.
     outputs: dict[str, Any] = {}
     """outputs"""

     node_run_steps: int = 0
     """node run steps"""
@@ -1,11 +1,29 @@
 from collections.abc import Mapping
-from typing import Any, Literal, Optional
+from typing import Annotated, Any, Literal, Optional

-from pydantic import BaseModel, Field
+from pydantic import AfterValidator, BaseModel, Field

+from core.variables.types import SegmentType
 from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
 from core.workflow.utils.condition.entities import Condition

+_VALID_VAR_TYPE = frozenset(
+    [
+        SegmentType.STRING,
+        SegmentType.NUMBER,
+        SegmentType.OBJECT,
+        SegmentType.ARRAY_STRING,
+        SegmentType.ARRAY_NUMBER,
+        SegmentType.ARRAY_OBJECT,
+    ]
+)
+
+
+def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
+    if seg_type not in _VALID_VAR_TYPE:
+        raise ValueError(...)
+    return seg_type
+
+
 class LoopVariableData(BaseModel):
     """
@@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
     """

     label: str
-    var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+    var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
     value_type: Literal["variable", "constant"]
     value: Optional[Any | list[str]] = None
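For readers unfamiliar with the `Annotated` + `AfterValidator` pattern used for `var_type` above, the standalone sketch below illustrates the same idea with a toy enum; it is not Dify's `SegmentType` and the names are made up for this example.

from enum import StrEnum  # Python 3.11+
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ValidationError


class VarType(StrEnum):
    STRING = "string"
    NUMBER = "number"
    SECRET = "secret"


# Only a subset of the enum is accepted for this field.
_ALLOWED = frozenset({VarType.STRING, VarType.NUMBER})


def _check_allowed(v: VarType) -> VarType:
    if v not in _ALLOWED:
        raise ValueError(f"unsupported var_type: {v}")
    return v


class LoopVariable(BaseModel):
    var_type: Annotated[VarType, AfterValidator(_check_allowed)]


LoopVariable(var_type=VarType.STRING)       # passes the validator
try:
    LoopVariable(var_type=VarType.SECRET)   # rejected by the validator
except ValidationError:
    pass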
@@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast

 from configs import dify_config
 from core.variables import (
-    ArrayNumberSegment,
-    ArrayObjectSegment,
-    ArrayStringSegment,
-    IntegerSegment,
-    ObjectSegment,
     Segment,
     SegmentType,
-    StringSegment,
 )
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from core.workflow.nodes.loop.entities import LoopNodeData
 from core.workflow.utils.condition.processor import ConditionProcessor
+from factories.variable_factory import TypeMismatchError, build_segment_with_type

 if TYPE_CHECKING:
     from core.workflow.entities.variable_pool import VariablePool
@@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
         return variable_mapping

     @staticmethod
-    def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
+    def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
         """Get the appropriate segment type for a constant value."""
-        segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
-            "string": (StringSegment, SegmentType.STRING),
-            "number": (IntegerSegment, SegmentType.NUMBER),
-            "object": (ObjectSegment, SegmentType.OBJECT),
-            "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
-            "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
-            "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
-        }
         if var_type in ["array[string]", "array[number]", "array[object]"]:
-            if value:
+            if value and isinstance(value, str):
                 value = json.loads(value)
             else:
                 value = []
-        segment_info = segment_mapping.get(var_type)
-        if not segment_info:
-            raise ValueError(f"Invalid variable type: {var_type}")
-        segment_class, value_type = segment_info
-        return segment_class(value=value, value_type=value_type)
+        try:
+            return build_segment_with_type(var_type, value)
+        except TypeMismatchError as type_exc:
+            # Attempt to parse the value as a JSON-encoded string, if applicable.
+            if not isinstance(value, str):
+                raise
+            try:
+                value = json.loads(value)
+            except ValueError:
+                raise type_exc
+            return build_segment_with_type(var_type, value)
@@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):

     def _run(self) -> NodeRunResult:
         node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
-        system_inputs = self.graph_runtime_state.variable_pool.system_variables
+        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()

         # TODO: System variables should be directly accessible, no need for special handling
         # Set system variables as node outputs.
@@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):


 def get_zero_value(t: SegmentType):
+    # TODO(QuantumGhost): this should be a method of `SegmentType`.
     match t:
         case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
             return variable_factory.build_segment([])
@@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
             return variable_factory.build_segment({})
         case SegmentType.STRING:
             return variable_factory.build_segment("")
+        case SegmentType.INTEGER:
+            return variable_factory.build_segment(0)
+        case SegmentType.FLOAT:
+            return variable_factory.build_segment(0.0)
         case SegmentType.NUMBER:
             return variable_factory.build_segment(0)
         case _:
@@ -1,5 +1,6 @@
 from core.variables import SegmentType

+# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
 EMPTY_VALUE_MAPPING = {
     SegmentType.STRING: "",
     SegmentType.NUMBER: 0,
@@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
         case Operation.OVER_WRITE | Operation.CLEAR:
             return True
         case Operation.SET:
-            return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
+            return variable_type in {
+                SegmentType.OBJECT,
+                SegmentType.STRING,
+                SegmentType.NUMBER,
+                SegmentType.INTEGER,
+                SegmentType.FLOAT,
+            }
         case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
             # Only number variable can be added, subtracted, multiplied or divided
-            return variable_type == SegmentType.NUMBER
+            return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
         case Operation.APPEND | Operation.EXTEND:
             # Only array variable can be appended or extended
             return variable_type in {
@@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
     match variable_type:
         case SegmentType.STRING | SegmentType.OBJECT:
             return operation in {Operation.OVER_WRITE, Operation.SET}
-        case SegmentType.NUMBER:
+        case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
             return operation in {
                 Operation.OVER_WRITE,
                 Operation.SET,
@@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
         case SegmentType.STRING:
             return isinstance(value, str)

-        case SegmentType.NUMBER:
+        case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
            if not isinstance(value, int | float):
                return False
            if operation == Operation.DIVIDE and value == 0:
api/core/workflow/system_variable.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+from collections.abc import Sequence
+from typing import Any
+
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
+
+from core.file.models import File
+from core.workflow.enums import SystemVariableKey
+
+
+class SystemVariable(BaseModel):
+    """A model for managing system variables.
+
+    Fields with a value of `None` are treated as absent and will not be included
+    in the variable pool.
+    """
+
+    model_config = ConfigDict(
+        extra="forbid",
+        serialize_by_alias=True,
+        validate_by_alias=True,
+    )
+
+    user_id: str | None = None
+
+    # Ideally, `app_id` and `workflow_id` should be required and not `None`.
+    # However, there are scenarios in the codebase where these fields are not set.
+    # To maintain compatibility, they are marked as optional here.
+    app_id: str | None = None
+    workflow_id: str | None = None
+
+    files: Sequence[File] = Field(default_factory=list)
+
+    # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
+    # To maintain compatibility with existing workflows, it must be serialized
+    # as `workflow_run_id` in dictionaries or JSON objects, and also referenced
+    # as `workflow_run_id` in the variable pool.
+    workflow_execution_id: str | None = Field(
+        validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
+        serialization_alias="workflow_run_id",
+        default=None,
+    )
+    # Chatflow related fields.
+    query: str | None = None
+    conversation_id: str | None = None
+    dialogue_count: int | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_json_fields(cls, data):
+        if isinstance(data, dict):
+            # For JSON validation, only allow workflow_run_id
+            if "workflow_execution_id" in data and "workflow_run_id" not in data:
+                # This is likely from direct instantiation, allow it
+                return data
+            elif "workflow_execution_id" in data and "workflow_run_id" in data:
+                # Both present, remove workflow_execution_id
+                data = data.copy()
+                data.pop("workflow_execution_id")
+                return data
+            return data
+        return data
+
+    @classmethod
+    def empty(cls) -> "SystemVariable":
+        return cls()
+
+    def to_dict(self) -> dict[SystemVariableKey, Any]:
+        # NOTE: This method is provided for compatibility with legacy code.
+        # New code should use the `SystemVariable` object directly instead of converting
+        # it to a dictionary, as this conversion results in the loss of type information
+        # for each key, making static analysis more difficult.
+
+        d: dict[SystemVariableKey, Any] = {
+            SystemVariableKey.FILES: self.files,
+        }
+        if self.user_id is not None:
+            d[SystemVariableKey.USER_ID] = self.user_id
+        if self.app_id is not None:
+            d[SystemVariableKey.APP_ID] = self.app_id
+        if self.workflow_id is not None:
+            d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
+        if self.workflow_execution_id is not None:
+            d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
+        if self.query is not None:
+            d[SystemVariableKey.QUERY] = self.query
+        if self.conversation_id is not None:
+            d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
+        if self.dialogue_count is not None:
+            d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
+        return d
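A brief usage sketch of the new module's aliasing behavior (based on the file above; the id values are made up): `workflow_execution_id` validates from either name and always serializes back under the legacy `workflow_run_id` key.

from core.workflow.system_variable import SystemVariable

# AliasChoices accepts either the new field name or the legacy key.
sv = SystemVariable.model_validate({"workflow_run_id": "run-123", "user_id": "u-1"})
assert sv.workflow_execution_id == "run-123"

# serialize_by_alias=True makes model_dump() emit the legacy key name.
dumped = sv.model_dump()
assert dumped["workflow_run_id"] == "run-123"
assert "workflow_execution_id" not in dumped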
@@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import (
 from core.workflow.enums import SystemVariableKey
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_entry import WorkflowEntry
 from libs.datetime_utils import naive_utc_now

@@ -43,7 +44,7 @@ class WorkflowCycleManager:
         self,
         *,
         application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
-        workflow_system_variables: dict[SystemVariableKey, Any],
+        workflow_system_variables: SystemVariable,
         workflow_info: CycleManagerWorkflowInfo,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
@@ -56,17 +57,22 @@

     def handle_workflow_run_start(self) -> WorkflowExecution:
         inputs = {**self._application_generate_entity.inputs}
-        for key, value in (self._workflow_system_variables or {}).items():
-            if key.value == "conversation":
-                continue
-            inputs[f"sys.{key.value}"] = value
+
+        # Iterate over SystemVariable fields using Pydantic's model_fields
+        if self._workflow_system_variables:
+            for field_name, value in self._workflow_system_variables.to_dict().items():
+                if field_name == SystemVariableKey.CONVERSATION_ID:
+                    continue
+                inputs[f"sys.{field_name}"] = value

         # handle special values
         inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

         # init workflow run
         # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
-        execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
+        execution_id = str(
+            self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
+        ) or str(uuid4())
         execution = WorkflowExecution.new(
             id_=execution_id,
             workflow_id=self._workflow_info.workflow_id,
@@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.event import NodeEvent
 from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
 from factories import file_factory
 from models.enums import UserFrom
@@ -254,7 +255,7 @@ class WorkflowEntry:

         # init variable pool
         variable_pool = VariablePool(
-            system_variables={},
+            system_variables=SystemVariable.empty(),
             user_inputs={},
             environment_variables=[],
         )