From 268da3133282b20f01d3bd5a46e3f69210aab399 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 25 Jun 2025 12:39:22 +0800 Subject: [PATCH] fix(api): adding variable to variable pool recursively while loading draft variables. (#21478) This PR fix the issue that `ObjectSegment` are not recursively added to the draft variable pool while loading draft variables from database. It also fixes an issue about loading variables with more than two elements in the its selector. Enhances #19735. Closes #21477. --- .../workflow/graph_engine/graph_engine.py | 17 +- api/core/workflow/utils/variable_utils.py | 28 ++++ api/core/workflow/variable_loader.py | 7 +- .../workflow_draft_variable_service.py | 3 +- .../workflow/utils/test_variable_utils.py | 148 ++++++++++++++++++ 5 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 api/core/workflow/utils/variable_utils.py create mode 100644 api/tests/unit_tests/core/workflow/utils/test_variable_utils.py diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2809afad0..61a7a2665 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -856,16 +857,12 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) + variable_utils.append_variables_recursively( + self.graph_runtime_state.variable_pool, + node_id, + variable_key_list, + variable_value, + ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py new file mode 100644 index 000000000..e68d990b6 --- /dev/null +++ b/api/core/workflow/utils/variable_utils.py @@ -0,0 +1,28 @@ +from core.variables.segments import ObjectSegment, Segment +from core.workflow.entities.variable_pool import VariablePool, VariableValue + + +def append_variables_recursively( + pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment +): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, ObjectSegment): + variable_dict = variable_value.value + elif isinstance(variable_value, dict): + variable_dict = variable_value + else: + return + + for key, value in variable_dict.items(): + # construct new key list + new_key_list = variable_key_list + [key] + append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 4842ee00a..1e13871d0 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -3,7 +3,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.variables import Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils import variable_utils class VariableLoader(Protocol): @@ -76,4 +78,7 @@ def load_into_variable_pool( variables_to_load.append(list(selector)) loaded = variable_loader.load_variables(variables_to_load) for var in loaded: - variable_pool.add(var.selector, var) + assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" + variable_utils.append_variables_recursively( + variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var + ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index cd30440b4..164693c2e 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -129,7 +129,8 @@ class WorkflowDraftVariableService: ) -> list[WorkflowDraftVariable]: ors = [] for selector in selectors: - node_id, name = selector + assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py new file mode 100644 index 000000000..f1cb937bb --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -0,0 +1,148 @@ +from typing import Any + +from core.variables.segments import ObjectSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.variable_utils import append_variables_recursively + + +class TestAppendVariablesRecursively: + """Test cases for append_variables_recursively function""" + + def test_append_simple_dict_value(self): + """Test appending a simple dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["output"] + variable_value = {"name": "John", "age": 30} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Check that nested variables are added recursively + name_var = pool.get([node_id] + variable_key_list + ["name"]) + assert name_var is not None + assert name_var.value == "John" + + age_var = pool.get([node_id] + variable_key_list + ["age"]) + assert age_var is not None + assert age_var.value == 30 + + def test_append_object_segment_value(self): + """Test appending an ObjectSegment value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["result"] + + # Create an ObjectSegment + obj_data = {"status": "success", "code": 200} + variable_value = ObjectSegment(value=obj_data) + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, ObjectSegment) + assert main_var.value == obj_data + + # Check that nested variables are added recursively + status_var = pool.get([node_id] + variable_key_list + ["status"]) + assert status_var is not None + assert status_var.value == "success" + + code_var = pool.get([node_id] + variable_key_list + ["code"]) + assert code_var is not None + assert code_var.value == 200 + + def test_append_nested_dict_value(self): + """Test appending a nested dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["data"] + + variable_value = { + "user": { + "profile": {"name": "Alice", "email": "alice@example.com"}, + "settings": {"theme": "dark", "notifications": True}, + }, + "metadata": {"version": "1.0", "timestamp": 1234567890}, + } + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check deeply nested variables + name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) + assert name_var is not None + assert name_var.value == "Alice" + + email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) + assert email_var is not None + assert email_var.value == "alice@example.com" + + theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) + assert theme_var is not None + assert theme_var.value == "dark" + + notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) + assert notifications_var is not None + assert notifications_var.value == 1 # Boolean True is converted to integer 1 + + version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) + assert version_var is not None + assert version_var.value == "1.0" + + def test_append_non_dict_value(self): + """Test appending a non-dictionary value (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["simple"] + variable_value = "simple_string" + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_segment_non_object_value(self): + """Test appending a Segment that is not ObjectSegment (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["text"] + variable_value = StringSegment(value="Hello World") + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, StringSegment) + assert main_var.value == "Hello World" + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_empty_dict_value(self): + """Test appending an empty dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["empty"] + variable_value: dict[str, Any] = {} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == {} + + # Ensure only the main variable is created (no recursion for empty dict) + assert len(pool.variable_dictionary[node_id]) == 1