refactor: simplify variable pool key structure and improve type safety (#23732)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -4,4 +4,4 @@
|
||||
#
|
||||
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
||||
# to extract part of the variable value.
|
||||
MIN_SELECTORS_LENGTH = 2
|
||||
SELECTORS_LENGTH = 2
|
||||
|
@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, ObjectSegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@@ -24,7 +24,7 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
@@ -36,6 +36,7 @@ class VariablePool(BaseModel):
|
||||
)
|
||||
system_variables: SystemVariable = Field(
|
||||
description="System variables",
|
||||
default_factory=SystemVariable.empty,
|
||||
)
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
description="Environment variables.",
|
||||
@@ -58,23 +59,29 @@ class VariablePool(BaseModel):
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
Add a variable to the variable pool.
|
||||
|
||||
NOTE: You should not add a non-Segment value to the variable pool
|
||||
even if it is allowed now.
|
||||
This method accepts a selector path and a value, converting the value
|
||||
to a Variable object if necessary before storing it in the pool.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector for the variable.
|
||||
value (VariableValue): The value of the variable.
|
||||
selector: A two-element sequence containing [node_id, variable_name].
|
||||
The selector must have exactly 2 elements to be valid.
|
||||
value: The value to store. Can be a Variable, Segment, or any value
|
||||
that can be converted to a Segment (str, int, float, dict, list, File).
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
ValueError: If selector length is not exactly 2 elements.
|
||||
|
||||
Returns:
|
||||
None
|
||||
Note:
|
||||
While non-Segment values are currently accepted and automatically
|
||||
converted, it's recommended to pass Segment or Variable objects directly.
|
||||
"""
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
raise ValueError("Invalid selector")
|
||||
if len(selector) != SELECTORS_LENGTH:
|
||||
raise ValueError(
|
||||
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
|
||||
f"got {len(selector)} elements"
|
||||
)
|
||||
|
||||
if isinstance(value, Variable):
|
||||
variable = value
|
||||
@@ -84,57 +91,85 @@ class VariablePool(BaseModel):
|
||||
segment = variable_factory.build_segment(value)
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
||||
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
||||
return selector[0], hash(tuple(selector[1:]))
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||
return selector[0], selector[1]
|
||||
|
||||
def _has(self, selector: Sequence[str]) -> bool:
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
if key not in self.variable_dictionary:
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
if node_id not in self.variable_dictionary:
|
||||
return False
|
||||
if hash_key not in self.variable_dictionary[key]:
|
||||
if name not in self.variable_dictionary[node_id]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
Retrieve a variable's value from the pool as a Segment.
|
||||
|
||||
This method supports both simple selectors [node_id, variable_name] and
|
||||
extended selectors that include attribute access for FileSegment and
|
||||
ObjectSegment types.
|
||||
|
||||
Args:
|
||||
selector (Sequence[str]): The selector used to identify the variable.
|
||||
selector: A sequence with at least 2 elements:
|
||||
- [node_id, variable_name]: Returns the full segment
|
||||
- [node_id, variable_name, attr, ...]: Returns a nested value
|
||||
from FileSegment (e.g., 'url', 'name') or ObjectSegment
|
||||
|
||||
Returns:
|
||||
Any: The value associated with the given selector.
|
||||
The Segment associated with the selector, or None if not found.
|
||||
Returns None if selector has fewer than 2 elements.
|
||||
|
||||
Raises:
|
||||
ValueError: If the selector is invalid.
|
||||
ValueError: If attempting to access an invalid FileAttribute.
|
||||
"""
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
return None
|
||||
|
||||
key, hash_key = self._selector_to_keys(selector)
|
||||
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
segment: Segment | None = self.variable_dictionary[node_id].get(name)
|
||||
|
||||
if value is None:
|
||||
selector, attr = selector[:-1], selector[-1]
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
if len(selector) == 2:
|
||||
return segment
|
||||
|
||||
if isinstance(segment, FileSegment):
|
||||
attr = selector[2]
|
||||
# Python support `attr in FileAttribute` after 3.12
|
||||
if attr not in {item.value for item in FileAttribute}:
|
||||
return None
|
||||
value = self.get(selector)
|
||||
if not isinstance(value, FileSegment | NoneSegment):
|
||||
return None
|
||||
if isinstance(value, FileSegment):
|
||||
attr = FileAttribute(attr)
|
||||
attr_value = file_manager.get_attr(file=value.value, attr=attr)
|
||||
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
|
||||
return variable_factory.build_segment(attr_value)
|
||||
return value
|
||||
|
||||
return value
|
||||
# Navigate through nested attributes
|
||||
result: Any = segment
|
||||
for attr in selector[2:]:
|
||||
result = self._extract_value(result)
|
||||
result = self._get_nested_attribute(result, attr)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
# Return result as Segment
|
||||
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
|
||||
|
||||
def _extract_value(self, obj: Any) -> Any:
|
||||
"""Extract the actual value from an ObjectSegment."""
|
||||
return obj.value if isinstance(obj, ObjectSegment) else obj
|
||||
|
||||
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
|
||||
"""Get a nested attribute from a dictionary-like object."""
|
||||
if not isinstance(obj, dict):
|
||||
return None
|
||||
return obj.get(attr)
|
||||
|
||||
def remove(self, selector: Sequence[str], /):
|
||||
"""
|
||||
|
@@ -15,7 +15,7 @@ from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils import variable_utils
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
@@ -701,11 +700,9 @@ class GraphEngine:
|
||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
# Add variables to variable pool
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node.node_id, variable_key], variable_value
|
||||
)
|
||||
yield NodeRunExceptionEvent(
|
||||
error=run_result.error or "System Error",
|
||||
@@ -758,11 +755,9 @@ class GraphEngine:
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value,
|
||||
# Add variables to variable pool
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node.node_id, variable_key], variable_value
|
||||
)
|
||||
|
||||
# When setting metadata, convert to dict first
|
||||
@@ -851,21 +846,6 @@ class GraphEngine:
|
||||
logger.exception("Node %s run failed", node.title)
|
||||
raise e
|
||||
|
||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_utils.append_variables_recursively(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
node_id,
|
||||
variable_key_list,
|
||||
variable_value,
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
|
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables import Segment
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
# Use double underscore (`__`) prefix for internal variables
|
||||
@@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any])
|
||||
|
||||
|
||||
def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
node_id, var_name = selector[:2]
|
||||
return UpdatedVariable(
|
||||
|
@@ -4,7 +4,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||
selector = item.value
|
||||
if not isinstance(selector, list):
|
||||
raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
|
||||
selector_str = ".".join(selector)
|
||||
key = f"{node_id}.#{selector_str}#"
|
||||
|
@@ -1,29 +0,0 @@
|
||||
from core.variables.segments import ObjectSegment, Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
|
||||
|
||||
def append_variables_recursively(
|
||||
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
|
||||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param pool: variable pool to append variables to
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, ObjectSegment):
|
||||
variable_dict = variable_value.value
|
||||
elif isinstance(variable_value, dict):
|
||||
variable_dict = variable_value
|
||||
else:
|
||||
return
|
||||
|
||||
for key, value in variable_dict.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
|
@@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils import variable_utils
|
||||
|
||||
|
||||
class VariableLoader(Protocol):
|
||||
@@ -78,7 +77,7 @@ def load_into_variable_pool(
|
||||
variables_to_load.append(list(selector))
|
||||
loaded = variable_loader.load_variables(variables_to_load)
|
||||
for var in loaded:
|
||||
assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||
variable_utils.append_variables_recursively(
|
||||
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
|
||||
)
|
||||
assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}"
|
||||
# Add variable directly to the pool
|
||||
# The variable pool expects 2-element selectors [node_id, variable_name]
|
||||
variable_pool.add([var.selector[0], var.selector[1]], var)
|
||||
|
@@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.variables import Segment, StringSegment, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@@ -147,7 +147,7 @@ class WorkflowDraftVariableService:
|
||||
) -> list[WorkflowDraftVariable]:
|
||||
ors = []
|
||||
for selector in selectors:
|
||||
assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
||||
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
|
||||
node_id, name = selector[:2]
|
||||
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||
|
||||
@@ -608,7 +608,7 @@ class DraftVariableSaver:
|
||||
|
||||
for item in updated_variables:
|
||||
selector = item.selector
|
||||
if len(selector) < MIN_SELECTORS_LENGTH:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise Exception("selector too short")
|
||||
# NOTE(QuantumGhost): only the following two kinds of variable could be updated by
|
||||
# VariableAssigner: ConversationVariable and iteration variable.
|
||||
|
@@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file):
|
||||
|
||||
|
||||
def test_use_long_selector(pool):
|
||||
pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value"))
|
||||
# The add method now only accepts 2-element selectors (node_id, variable_name)
|
||||
# Store nested data as an ObjectSegment instead
|
||||
nested_data = {"part_2": "test_value"}
|
||||
pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data))
|
||||
|
||||
# The get method supports longer selectors for nested access
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
@@ -280,8 +284,10 @@ class TestVariablePoolSerialization:
|
||||
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
|
||||
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
|
||||
|
||||
# Add nested variables
|
||||
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
|
||||
# Add nested variables as ObjectSegment
|
||||
# The add method only accepts 2-element selectors
|
||||
nested_obj = {"deep": {"var": "deep_value"}}
|
||||
pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj))
|
||||
|
||||
def test_system_variables(self):
|
||||
sys_vars = SystemVariable(
|
||||
|
@@ -1,148 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils.variable_utils import append_variables_recursively
|
||||
|
||||
|
||||
class TestAppendVariablesRecursively:
|
||||
"""Test cases for append_variables_recursively function"""
|
||||
|
||||
def test_append_simple_dict_value(self):
|
||||
"""Test appending a simple dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["output"]
|
||||
variable_value = {"name": "John", "age": 30}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
name_var = pool.get([node_id] + variable_key_list + ["name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "John"
|
||||
|
||||
age_var = pool.get([node_id] + variable_key_list + ["age"])
|
||||
assert age_var is not None
|
||||
assert age_var.value == 30
|
||||
|
||||
def test_append_object_segment_value(self):
|
||||
"""Test appending an ObjectSegment value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["result"]
|
||||
|
||||
# Create an ObjectSegment
|
||||
obj_data = {"status": "success", "code": 200}
|
||||
variable_value = ObjectSegment(value=obj_data)
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, ObjectSegment)
|
||||
assert main_var.value == obj_data
|
||||
|
||||
# Check that nested variables are added recursively
|
||||
status_var = pool.get([node_id] + variable_key_list + ["status"])
|
||||
assert status_var is not None
|
||||
assert status_var.value == "success"
|
||||
|
||||
code_var = pool.get([node_id] + variable_key_list + ["code"])
|
||||
assert code_var is not None
|
||||
assert code_var.value == 200
|
||||
|
||||
def test_append_nested_dict_value(self):
|
||||
"""Test appending a nested dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["data"]
|
||||
|
||||
variable_value = {
|
||||
"user": {
|
||||
"profile": {"name": "Alice", "email": "alice@example.com"},
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
},
|
||||
"metadata": {"version": "1.0", "timestamp": 1234567890},
|
||||
}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check deeply nested variables
|
||||
name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"])
|
||||
assert name_var is not None
|
||||
assert name_var.value == "Alice"
|
||||
|
||||
email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"])
|
||||
assert email_var is not None
|
||||
assert email_var.value == "alice@example.com"
|
||||
|
||||
theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"])
|
||||
assert theme_var is not None
|
||||
assert theme_var.value == "dark"
|
||||
|
||||
notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"])
|
||||
assert notifications_var is not None
|
||||
assert notifications_var.value == 1 # Boolean True is converted to integer 1
|
||||
|
||||
version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"])
|
||||
assert version_var is not None
|
||||
assert version_var.value == "1.0"
|
||||
|
||||
def test_append_non_dict_value(self):
|
||||
"""Test appending a non-dictionary value (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["simple"]
|
||||
variable_value = "simple_string"
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == variable_value
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_segment_non_object_value(self):
|
||||
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["text"]
|
||||
variable_value = StringSegment(value="Hello World")
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that only the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert isinstance(main_var, StringSegment)
|
||||
assert main_var.value == "Hello World"
|
||||
|
||||
# Ensure no additional variables are created
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
||||
|
||||
def test_append_empty_dict_value(self):
|
||||
"""Test appending an empty dictionary value"""
|
||||
pool = VariablePool.empty()
|
||||
node_id = "test_node"
|
||||
variable_key_list = ["empty"]
|
||||
variable_value: dict[str, Any] = {}
|
||||
|
||||
append_variables_recursively(pool, node_id, variable_key_list, variable_value)
|
||||
|
||||
# Check that the main variable is added
|
||||
main_var = pool.get([node_id] + variable_key_list)
|
||||
assert main_var is not None
|
||||
assert main_var.value == {}
|
||||
|
||||
# Ensure only the main variable is created (no recursion for empty dict)
|
||||
assert len(pool.variable_dictionary[node_id]) == 1
|
Reference in New Issue
Block a user