diff --git a/api/core/variables/consts.py b/api/core/variables/consts.py index 03b277d61..8f3f78f74 100644 --- a/api/core/variables/consts.py +++ b/api/core/variables/consts.py @@ -4,4 +4,4 @@ # # If the selector length is more than 2, the remaining parts are the keys / indexes paths used # to extract part of the variable value. -MIN_SELECTORS_LENGTH = 2 +SELECTORS_LENGTH = 2 diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index fbb8df6b0..fb0794844 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH -from core.variables.segments import FileSegment, NoneSegment +from core.variables.consts import SELECTORS_LENGTH +from core.variables.segments import FileSegment, ObjectSegment from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.system_variable import SystemVariable @@ -24,7 +24,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -36,6 +36,7 @@ class VariablePool(BaseModel): ) system_variables: SystemVariable = Field( description="System variables", + default_factory=SystemVariable.empty, ) environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", @@ -58,23 +59,29 @@ class VariablePool(BaseModel): def add(self, selector: Sequence[str], value: Any, /) -> None: """ - Adds a variable to the variable pool. + Add a variable to the variable pool. - NOTE: You should not add a non-Segment value to the variable pool - even if it is allowed now. + This method accepts a selector path and a value, converting the value + to a Variable object if necessary before storing it in the pool. Args: - selector (Sequence[str]): The selector for the variable. - value (VariableValue): The value of the variable. + selector: A two-element sequence containing [node_id, variable_name]. + The selector must have exactly 2 elements to be valid. + value: The value to store. Can be a Variable, Segment, or any value + that can be converted to a Segment (str, int, float, dict, list, File). Raises: - ValueError: If the selector is invalid. + ValueError: If selector length is not exactly 2 elements. - Returns: - None + Note: + While non-Segment values are currently accepted and automatically + converted, it's recommended to pass Segment or Variable objects directly. """ - if len(selector) < MIN_SELECTORS_LENGTH: - raise ValueError("Invalid selector") + if len(selector) != SELECTORS_LENGTH: + raise ValueError( + f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " + f"got {len(selector)} elements" + ) if isinstance(value, Variable): variable = value @@ -84,57 +91,85 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - key, hash_key = self._selector_to_keys(selector) + node_id, name = self._selector_to_keys(selector) # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: - return selector[0], hash(tuple(selector[1:])) + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: + return selector[0], selector[1] def _has(self, selector: Sequence[str]) -> bool: - key, hash_key = self._selector_to_keys(selector) - if key not in self.variable_dictionary: + node_id, name = self._selector_to_keys(selector) + if node_id not in self.variable_dictionary: return False - if hash_key not in self.variable_dictionary[key]: + if name not in self.variable_dictionary[node_id]: return False return True def get(self, selector: Sequence[str], /) -> Segment | None: """ - Retrieves the value from the variable pool based on the given selector. + Retrieve a variable's value from the pool as a Segment. + + This method supports both simple selectors [node_id, variable_name] and + extended selectors that include attribute access for FileSegment and + ObjectSegment types. Args: - selector (Sequence[str]): The selector used to identify the variable. + selector: A sequence with at least 2 elements: + - [node_id, variable_name]: Returns the full segment + - [node_id, variable_name, attr, ...]: Returns a nested value + from FileSegment (e.g., 'url', 'name') or ObjectSegment Returns: - Any: The value associated with the given selector. + The Segment associated with the selector, or None if not found. + Returns None if selector has fewer than 2 elements. Raises: - ValueError: If the selector is invalid. + ValueError: If attempting to access an invalid FileAttribute. """ - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: return None - key, hash_key = self._selector_to_keys(selector) - value: Segment | None = self.variable_dictionary[key].get(hash_key) + node_id, name = self._selector_to_keys(selector) + segment: Segment | None = self.variable_dictionary[node_id].get(name) - if value is None: - selector, attr = selector[:-1], selector[-1] + if segment is None: + return None + + if len(selector) == 2: + return segment + + if isinstance(segment, FileSegment): + attr = selector[2] # Python support `attr in FileAttribute` after 3.12 if attr not in {item.value for item in FileAttribute}: return None - value = self.get(selector) - if not isinstance(value, FileSegment | NoneSegment): - return None - if isinstance(value, FileSegment): - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=value.value, attr=attr) - return variable_factory.build_segment(attr_value) - return value + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=segment.value, attr=attr) + return variable_factory.build_segment(attr_value) - return value + # Navigate through nested attributes + result: Any = segment + for attr in selector[2:]: + result = self._extract_value(result) + result = self._get_nested_attribute(result, attr) + if result is None: + return None + + # Return result as Segment + return result if isinstance(result, Segment) else variable_factory.build_segment(result) + + def _extract_value(self, obj: Any) -> Any: + """Extract the actual value from an ObjectSegment.""" + return obj.value if isinstance(obj, ObjectSegment) else obj + + def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any: + """Get a nested attribute from a dictionary-like object.""" + if not isinstance(obj, dict): + return None + return obj.get(attr) def remove(self, selector: Sequence[str], /): """ diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ef13277e0..b9663d32f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -15,7 +15,7 @@ from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult -from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -51,7 +51,6 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -701,11 +700,9 @@ class GraphEngine: route_node_state.status = RouteNodeState.Status.EXCEPTION if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, + # Add variables to variable pool + self.graph_runtime_state.variable_pool.add( + [node.node_id, variable_key], variable_value ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", @@ -758,11 +755,9 @@ class GraphEngine: # append node output variables to variable pool if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, + # Add variables to variable pool + self.graph_runtime_state.variable_pool.add( + [node.node_id, variable_key], variable_value ) # When setting metadata, convert to dict first @@ -851,21 +846,6 @@ class GraphEngine: logger.exception("Node %s run failed", node.title) raise e - def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): - """ - Append variables recursively - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_utils.append_variables_recursively( - self.graph_runtime_state.variable_pool, - node_id, - variable_key_list, - variable_value, - ) - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ Check timeout diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 0d2822233..48deda724 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -4,7 +4,7 @@ from typing import Any, TypeVar from pydantic import BaseModel from core.variables import Segment -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables @@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any]) def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise Exception("selector too short") node_id, var_name = selector[:2] return UpdatedVariable( diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index c0215cae7..00ee921ce 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -4,7 +4,7 @@ from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult @@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ selector = item.value if not isinstance(selector, list): raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise InvalidDataError(f"selector too short, {node_id=}, {item=}") selector_str = ".".join(selector) key = f"{node_id}.#{selector_str}#" diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py deleted file mode 100644 index 868868315..000000000 --- a/api/core/workflow/utils/variable_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -from core.variables.segments import ObjectSegment, Segment -from core.workflow.entities.variable_pool import VariablePool, VariableValue - - -def append_variables_recursively( - pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment -): - """ - Append variables recursively - :param pool: variable pool to append variables to - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, ObjectSegment): - variable_dict = variable_value.value - elif isinstance(variable_value, dict): - variable_dict = variable_value - else: - return - - for key, value in variable_dict.items(): - # construct new key list - new_key_list = variable_key_list + [key] - append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 1e13871d0..a35215855 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.variables import Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils import variable_utils class VariableLoader(Protocol): @@ -78,7 +77,7 @@ def load_into_variable_pool( variables_to_load.append(list(selector)) loaded = variable_loader.load_variables(variables_to_load) for var in loaded: - assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" - variable_utils.append_variables_recursively( - variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var - ) + assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" + # Add variable directly to the pool + # The variable pool expects 2-element selectors [node_id, variable_name] + variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 6bbb3bca0..b52f4924b 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.variables import Segment, StringSegment, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ArrayFileSegment, FileSegment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -147,7 +147,7 @@ class WorkflowDraftVariableService: ) -> list[WorkflowDraftVariable]: ors = [] for selector in selectors: - assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}" node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) @@ -608,7 +608,7 @@ class DraftVariableSaver: for item in updated_variables: selector = item.selector - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise Exception("selector too short") # NOTE(QuantumGhost): only the following two kinds of variable could be updated by # VariableAssigner: ConversationVariable and iteration variable. diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index c65b60cb4..c0330b944 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file): def test_use_long_selector(pool): - pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value")) + # The add method now only accepts 2-element selectors (node_id, variable_name) + # Store nested data as an ObjectSegment instead + nested_data = {"part_2": "test_value"} + pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data)) + # The get method supports longer selectors for nested access result = pool.get(("node_1", "part_1", "part_2")) assert result is not None assert result.value == "test_value" @@ -280,8 +284,10 @@ class TestVariablePoolSerialization: pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) - # Add nested variables - pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) + # Add nested variables as ObjectSegment + # The add method only accepts 2-element selectors + nested_obj = {"deep": {"var": "deep_value"}} + pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj)) def test_system_variables(self): sys_vars = SystemVariable( diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py deleted file mode 100644 index 54bf6558b..000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Any - -from core.variables.segments import ObjectSegment, StringSegment -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils.variable_utils import append_variables_recursively - - -class TestAppendVariablesRecursively: - """Test cases for append_variables_recursively function""" - - def test_append_simple_dict_value(self): - """Test appending a simple dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["output"] - variable_value = {"name": "John", "age": 30} - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == variable_value - - # Check that nested variables are added recursively - name_var = pool.get([node_id] + variable_key_list + ["name"]) - assert name_var is not None - assert name_var.value == "John" - - age_var = pool.get([node_id] + variable_key_list + ["age"]) - assert age_var is not None - assert age_var.value == 30 - - def test_append_object_segment_value(self): - """Test appending an ObjectSegment value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["result"] - - # Create an ObjectSegment - obj_data = {"status": "success", "code": 200} - variable_value = ObjectSegment(value=obj_data) - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert isinstance(main_var, ObjectSegment) - assert main_var.value == obj_data - - # Check that nested variables are added recursively - status_var = pool.get([node_id] + variable_key_list + ["status"]) - assert status_var is not None - assert status_var.value == "success" - - code_var = pool.get([node_id] + variable_key_list + ["code"]) - assert code_var is not None - assert code_var.value == 200 - - def test_append_nested_dict_value(self): - """Test appending a nested dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["data"] - - variable_value = { - "user": { - "profile": {"name": "Alice", "email": "alice@example.com"}, - "settings": {"theme": "dark", "notifications": True}, - }, - "metadata": {"version": "1.0", "timestamp": 1234567890}, - } - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check deeply nested variables - name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) - assert name_var is not None - assert name_var.value == "Alice" - - email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) - assert email_var is not None - assert email_var.value == "alice@example.com" - - theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) - assert theme_var is not None - assert theme_var.value == "dark" - - notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) - assert notifications_var is not None - assert notifications_var.value == 1 # Boolean True is converted to integer 1 - - version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) - assert version_var is not None - assert version_var.value == "1.0" - - def test_append_non_dict_value(self): - """Test appending a non-dictionary value (should not recurse)""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["simple"] - variable_value = "simple_string" - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that only the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == variable_value - - # Ensure no additional variables are created - assert len(pool.variable_dictionary[node_id]) == 1 - - def test_append_segment_non_object_value(self): - """Test appending a Segment that is not ObjectSegment (should not recurse)""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["text"] - variable_value = StringSegment(value="Hello World") - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that only the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert isinstance(main_var, StringSegment) - assert main_var.value == "Hello World" - - # Ensure no additional variables are created - assert len(pool.variable_dictionary[node_id]) == 1 - - def test_append_empty_dict_value(self): - """Test appending an empty dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["empty"] - variable_value: dict[str, Any] = {} - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == {} - - # Ensure only the main variable is created (no recursion for empty dict) - assert len(pool.variable_dictionary[node_id]) == 1