refactor: simplify variable pool key structure and improve type safety (#23732)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-08-11 18:10:04 +08:00
committed by GitHub
parent 223c1a8089
commit 577062b93a
10 changed files with 102 additions and 259 deletions

View File

@@ -4,4 +4,4 @@
#
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
# to extract part of the variable value.
MIN_SELECTORS_LENGTH = 2
SELECTORS_LENGTH = 2

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.system_variable import SystemVariable
@@ -24,7 +24,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -36,6 +36,7 @@ class VariablePool(BaseModel):
)
system_variables: SystemVariable = Field(
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
@@ -58,23 +59,29 @@ class VariablePool(BaseModel):
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
Add a variable to the variable pool.
NOTE: You should not add a non-Segment value to the variable pool
even if it is allowed now.
This method accepts a selector path and a value, converting the value
to a Variable object if necessary before storing it in the pool.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
selector: A two-element sequence containing [node_id, variable_name].
The selector must have exactly 2 elements to be valid.
value: The value to store. Can be a Variable, Segment, or any value
that can be converted to a Segment (str, int, float, dict, list, File).
Raises:
ValueError: If the selector is invalid.
ValueError: If selector length is not exactly 2 elements.
Returns:
None
Note:
While non-Segment values are currently accepted and automatically
converted, it's recommended to pass Segment or Variable objects directly.
"""
if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")
if len(selector) != SELECTORS_LENGTH:
raise ValueError(
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
variable = value
@@ -84,57 +91,85 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
key, hash_key = self._selector_to_keys(selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
return selector[0], selector[1]
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
node_id, name = self._selector_to_keys(selector)
if node_id not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
if name not in self.variable_dictionary[node_id]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieves the value from the variable pool based on the given selector.
Retrieve a variable's value from the pool as a Segment.
This method supports both simple selectors [node_id, variable_name] and
extended selectors that include attribute access for FileSegment and
ObjectSegment types.
Args:
selector (Sequence[str]): The selector used to identify the variable.
selector: A sequence with at least 2 elements:
- [node_id, variable_name]: Returns the full segment
- [node_id, variable_name, attr, ...]: Returns a nested value
from FileSegment (e.g., 'url', 'name') or ObjectSegment
Returns:
Any: The value associated with the given selector.
The Segment associated with the selector, or None if not found.
Returns None if selector has fewer than 2 elements.
Raises:
ValueError: If the selector is invalid.
ValueError: If attempting to access an invalid FileAttribute.
"""
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
return None
key, hash_key = self._selector_to_keys(selector)
value: Segment | None = self.variable_dictionary[key].get(hash_key)
node_id, name = self._selector_to_keys(selector)
segment: Segment | None = self.variable_dictionary[node_id].get(name)
if value is None:
selector, attr = selector[:-1], selector[-1]
if segment is None:
return None
if len(selector) == 2:
return segment
if isinstance(segment, FileSegment):
attr = selector[2]
# Python support `attr in FileAttribute` after 3.12
if attr not in {item.value for item in FileAttribute}:
return None
value = self.get(selector)
if not isinstance(value, FileSegment | NoneSegment):
return None
if isinstance(value, FileSegment):
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=value.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
# Navigate through nested attributes
result: Any = segment
for attr in selector[2:]:
result = self._extract_value(result)
result = self._get_nested_attribute(result, attr)
if result is None:
return None
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any) -> Any:
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
return None
return obj.get(attr)
def remove(self, selector: Sequence[str], /):
"""

View File

@@ -15,7 +15,7 @@ from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -701,11 +700,9 @@ class GraphEngine:
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
# Add variables to variable pool
self.graph_runtime_state.variable_pool.add(
[node.node_id, variable_key], variable_value
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
@@ -758,11 +755,9 @@ class GraphEngine:
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
# Add variables to variable pool
self.graph_runtime_state.variable_pool.add(
[node.node_id, variable_key], variable_value
)
# When setting metadata, convert to dict first
@@ -851,21 +846,6 @@ class GraphEngine:
logger.exception("Node %s run failed", node.title)
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_utils.append_variables_recursively(
self.graph_runtime_state.variable_pool,
node_id,
variable_key_list,
variable_value,
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout

View File

@@ -4,7 +4,7 @@ from typing import Any, TypeVar
from pydantic import BaseModel
from core.variables import Segment
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType
# Use double underscore (`__`) prefix for internal variables
@@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return UpdatedVariable(

View File

@@ -4,7 +4,7 @@ from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
@@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
selector = item.value
if not isinstance(selector, list):
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
selector_str = ".".join(selector)
key = f"{node_id}.#{selector_str}#"

View File

@@ -1,29 +0,0 @@
from core.variables.segments import ObjectSegment, Segment
from core.workflow.entities.variable_pool import VariablePool, VariableValue
def append_variables_recursively(
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
):
"""
Append variables recursively
:param pool: variable pool to append variables to
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, ObjectSegment):
variable_dict = variable_value.value
elif isinstance(variable_value, dict):
variable_dict = variable_value
else:
return
for key, value in variable_dict.items():
# construct new key list
new_key_list = variable_key_list + [key]
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)

View File

@@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils import variable_utils
class VariableLoader(Protocol):
@@ -78,7 +77,7 @@ def load_into_variable_pool(
variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load)
for var in loaded:
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
variable_utils.append_variables_recursively(
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
)
assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}"
# Add variable directly to the pool
# The variable pool expects 2-element selectors [node_id, variable_name]
variable_pool.add([var.selector[0], var.selector[1]], var)

View File

@@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
@@ -147,7 +147,7 @@ class WorkflowDraftVariableService:
) -> list[WorkflowDraftVariable]:
ors = []
for selector in selectors:
assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
node_id, name = selector[:2]
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
@@ -608,7 +608,7 @@ class DraftVariableSaver:
for item in updated_variables:
selector = item.selector
if len(selector) < MIN_SELECTORS_LENGTH:
if len(selector) < SELECTORS_LENGTH:
raise Exception("selector too short")
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
# VariableAssigner: ConversationVariable and iteration variable.

View File

@@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
def test_use_long_selector(pool):
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
# The add method now only accepts 2-element selectors (node_id, variable_name)
# Store nested data as an ObjectSegment instead
nested_data = {"part_2": "test_value"}
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
# The get method supports longer selectors for nested access
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
@@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
# Add nested variables as ObjectSegment
# The add method only accepts 2-element selectors
nested_obj = {"deep": {"var": "deep_value"}}
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
def test_system_variables(self):
sys_vars = SystemVariable(

View File

@@ -1,148 +0,0 @@
from typing import Any
from core.variables.segments import ObjectSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.variable_utils import append_variables_recursively
class TestAppendVariablesRecursively:
"""Test cases for append_variables_recursively function"""
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Check that nested variables are added recursively
name_var = pool.get([node_id] + variable_key_list + ["name"])
assert name_var is not None
assert name_var.value == "John"
age_var = pool.get([node_id] + variable_key_list + ["age"])
assert age_var is not None
assert age_var.value == 30
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
# Create an ObjectSegment
obj_data = {"status": "success", "code": 200}
variable_value = ObjectSegment(value=obj_data)
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, ObjectSegment)
assert main_var.value == obj_data
# Check that nested variables are added recursively
status_var = pool.get([node_id] + variable_key_list + ["status"])
assert status_var is not None
assert status_var.value == "success"
code_var = pool.get([node_id] + variable_key_list + ["code"])
assert code_var is not None
assert code_var.value == 200
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
variable_value = {
"user": {
"profile": {"name": "Alice", "email": "alice@example.com"},
"settings": {"theme": "dark", "notifications": True},
},
"metadata": {"version": "1.0", "timestamp": 1234567890},
}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check deeply nested variables
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
assert name_var is not None
assert name_var.value == "Alice"
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
assert email_var is not None
assert email_var.value == "alice@example.com"
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
assert theme_var is not None
assert theme_var.value == "dark"
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
assert notifications_var is not None
assert notifications_var.value == 1 # Boolean True is converted to integer 1
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
assert version_var is not None
assert version_var.value == "1.0"
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == variable_value
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that only the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert isinstance(main_var, StringSegment)
assert main_var.value == "Hello World"
# Ensure no additional variables are created
assert len(pool.variable_dictionary[node_id]) == 1
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
# Check that the main variable is added
main_var = pool.get([node_id] + variable_key_list)
assert main_var is not None
assert main_var.value == {}
# Ensure only the main variable is created (no recursion for empty dict)
assert len(pool.variable_dictionary[node_id]) == 1