feat: Persist Variables for Enhanced Debugging Workflow (#20699)

This pull request introduces a feature aimed at improving the debugging experience during workflow editing. With the addition of variable persistence, the system will automatically retain the output variables from previously executed nodes. These persisted variables can then be reused when debugging subsequent nodes, eliminating the need for repetitive manual input.

By streamlining this aspect of the workflow, the feature minimizes user errors and significantly reduces debugging effort, offering a smoother and more efficient experience.

Key highlights of this change:

- Automatic persistence of output variables for executed nodes.
- Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`).
- Enhanced debugging experience with reduced friction.

Closes #19735.
This commit is contained in:
QuantumGhost
2025-06-24 09:05:29 +08:00
committed by GitHub
parent 3113350e51
commit 10b738a296
106 changed files with 6025 additions and 718 deletions

View File

@@ -0,0 +1,302 @@
import datetime
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple
from flask_restful import marshal
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList
_TEST_APP_ID = "test_app_id"
_TEST_NODE_EXEC_ID = str(uuid.uuid4())
class TestWorkflowDraftVariableFields:
def test_conversation_variable(self):
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
conv_var.id = str(uuid.uuid4())
conv_var.visible = True
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(conv_var.id),
"type": conv_var.get_variable_type().value,
"name": "conv_var",
"description": "",
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
"value_type": "number",
"edited": False,
"visible": True,
}
)
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = 1
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_create_sys_variable(self):
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=_TEST_APP_ID,
name="sys_var",
value=build_segment("a"),
editable=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
sys_var.visible = True
expected_without_value = OrderedDict(
{
"id": str(sys_var.id),
"type": sys_var.get_variable_type().value,
"name": "sys_var",
"description": "",
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
"value_type": "string",
"edited": True,
"visible": True,
}
)
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = "a"
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_node_variable(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
}
)
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
class TestWorkflowDraftVariableList:
def test_workflow_draft_variable_list(self):
class TestCase(NamedTuple):
name: str
var_list: WorkflowDraftVariableList
expected: dict
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="test_var",
value=build_segment("a"),
visible=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var_dict = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "test_var",
"description": "",
"selector": ["test_node", "test_var"],
"value_type": "string",
"edited": False,
"visible": True,
}
)
cases = [
TestCase(
name="empty variable list",
var_list=WorkflowDraftVariableList(variables=[]),
expected=OrderedDict(
{
"items": [],
"total": None,
}
),
),
TestCase(
name="empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[], total=10),
expected=OrderedDict(
{
"items": [],
"total": 10,
}
),
),
TestCase(
name="non-empty variable list",
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": None,
}
),
),
TestCase(
name="non-empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": 10,
}
),
),
]
for idx, case in enumerate(cases, 1):
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
f"Test case {idx} failed, {case.name=}"
)
def test_workflow_node_variables_fields():
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "conv_var"
assert item_dict["value"] == 1
def test_workflow_file_variable_with_signed_url():
"""Test that File type variables include signed URLs in API responses."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
# Verify the File fields are preserved
assert value["id"] == test_file.id
assert value["filename"] == test_file.filename
assert value["type"] == test_file.type.value
assert value["transfer_method"] == test_file.transfer_method.value
assert value["size"] == test_file.size
# Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method)
remote_url = value["remote_url"]
assert remote_url is not None
assert isinstance(remote_url, str)
# For LOCAL_FILE, the URL should contain signature parameters
assert "timestamp=" in remote_url
assert "nonce=" in remote_url
assert "sign=" in remote_url
def test_workflow_file_variable_remote_url():
"""Test that File type variables with REMOTE_URL transfer method return the remote URL."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
remote_url = value["remote_url"]
# For REMOTE_URL, the URL should be the original remote URL
assert remote_url == test_file.remote_url

View File

@@ -1,165 +0,0 @@
from uuid import uuid4
import pytest
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import ArrayAnySegment
from factories import variable_factory
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None
def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]

View File

@@ -0,0 +1,25 @@
from core.file import File, FileTransferMethod, FileType
def test_file():
file = File(
id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
assert file.tenant_id == "test-tenant-id"
assert file.type == FileType.IMAGE
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.related_id == "test-related-id"
assert file.filename == "image.png"
assert file.extension == ".png"
assert file.mime_type == "image/png"
assert file.size == 67

View File

@@ -3,6 +3,7 @@ import uuid
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -197,7 +198,7 @@ def test_run():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 20
@@ -413,7 +414,7 @@ def test_run_parallel():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32
@@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode():
parallel_arr.append(item)
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32
for item in sequential_result:
@@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 64
@@ -846,7 +847,7 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": [None, None]}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}
assert count == 14
# execute remove abnormal output
@@ -857,5 +858,5 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": []}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])}
assert count == 14

View File

@@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P
from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"text/plain",
b"Hello, world!",
["Hello, world!"],
FileTransferMethod.LOCAL_FILE,
".txt",
),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
@@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
(
"text/plain",
b"Remote content",
["Remote content"],
FileTransferMethod.REMOTE_URL,
None,
),
],
)
def test_run_extract_text(
@@ -131,7 +144,7 @@ def test_run_extract_text(
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.outputs is not None
assert result.outputs["text"] == expected_text
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")

View File

@@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node):
},
]
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
for expected_file, result_file in zip(expected_files, result.outputs["result"].value):
assert expected_file["filename"] == result_file.filename
assert expected_file["type"] == result_file.type
assert expected_file["tenant_id"] == result_file.tenant_id

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@@ -63,10 +64,11 @@ def test_overwrite_string_variable():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -77,6 +79,9 @@ def test_overwrite_string_variable():
input_variable,
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -91,11 +96,20 @@ def test_overwrite_string_variable():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -148,9 +162,10 @@ def test_append_variable_to_array():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -160,6 +175,9 @@ def test_append_variable_to_array():
input_variable,
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -174,11 +192,22 @@ def test_append_variable_to_array():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -225,13 +254,17 @@ def test_clear_array():
value=["the first value"],
)
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -246,11 +279,20 @@ def test_clear_array():
"input_variable_selector": [],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None

View File

@@ -1,8 +1,12 @@
import pytest
from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
@@ -44,3 +48,38 @@ def test_use_long_selector(pool):
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
class TestVariablePool:
def test_constructor(self):
pool = VariablePool()
pool = VariablePool(
variable_dictionary={},
user_inputs={},
system_variables={},
environment_variables=[],
conversation_variables=[],
)
pool = VariablePool(
user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
environment_variables=[
segment_to_variable(
segment=build_segment(1),
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"],
name="env_var_1",
)
],
conversation_variables=[
segment_to_variable(
segment=build_segment("1"),
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"],
name="conv_var_1",
)
],
)
def test_constructor_with_invalid_system_variable_key(self):
with pytest.raises(ValidationError):
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore

View File

@@ -1,22 +1,10 @@
from core.variables import SecretVariable
import dataclasses
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
@@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]
def test_invalid_references():
@dataclasses.dataclass
class TestCase:
name: str
template: str
cases = [
TestCase(
name="lack of closing brace",
template="Hello, {{#sys.user_id#",
),
TestCase(
name="lack of opening brace",
template="Hello, #sys.user_id#}}",
),
TestCase(
name="lack selector name",
template="Hello, {{#sys#}}",
),
TestCase(
name="empty node name part",
template="Hello, {{#.user_id#}}",
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
selectors = variable_template_parser.extract_selectors_from_template(c.template)
assert selectors == [], fail_msg
parser = variable_template_parser.VariableTemplateParser(c.template)
assert parser.extract_variable_selectors() == [], fail_msg

View File

@@ -0,0 +1,865 @@
import math
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
import pytest
from hypothesis import given
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
SegmentType,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.variables.types import SegmentType
from factories import variable_factory
from factories.variable_factory import TypeMismatchError, build_segment_with_type
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None
def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]
def test_build_segment_none_type():
"""Test building NoneSegment from None value."""
segment = variable_factory.build_segment(None)
assert isinstance(segment, NoneSegment)
assert segment.value is None
assert segment.value_type == SegmentType.NONE
def test_build_segment_none_type_properties():
"""Test NoneSegment properties and methods."""
segment = variable_factory.build_segment(None)
assert segment.text == ""
assert segment.log == ""
assert segment.markdown == ""
assert segment.to_object() is None
def test_build_segment_array_file_single_file():
"""Test building ArrayFileSegment from list with single file."""
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
segment = variable_factory.build_segment([file])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 1
assert segment.value[0] == file
assert segment.value_type == SegmentType.ARRAY_FILE
def test_build_segment_array_file_multiple_files():
"""Test building ArrayFileSegment from list with multiple files."""
file1 = File(
id="test_file_id_1",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file1.png",
filename="test-file1",
extension=".png",
mime_type="image/png",
size=1000,
)
file2 = File(
id="test_file_id_2",
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_relation_id",
filename="test-file2",
extension=".txt",
mime_type="text/plain",
size=500,
)
segment = variable_factory.build_segment([file1, file2])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 2
assert segment.value[0] == file1
assert segment.value[1] == file2
assert segment.value_type == SegmentType.ARRAY_FILE
def test_build_segment_array_file_empty_list():
"""Test building ArrayFileSegment from empty list should create ArrayAnySegment."""
segment = variable_factory.build_segment([])
assert isinstance(segment, ArrayAnySegment)
assert segment.value == []
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_mixed_types():
"""Test building ArrayAnySegment from list with mixed types."""
mixed_values = ["string", 42, 3.14, {"key": "value"}, None]
segment = variable_factory.build_segment(mixed_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == mixed_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_with_nested_arrays():
"""Test building ArrayAnySegment from list containing arrays."""
nested_values = [["nested", "array"], [1, 2, 3], "string"]
segment = variable_factory.build_segment(nested_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == nested_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_mixed_with_files():
"""Test building ArrayAnySegment from list with files and other types."""
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
mixed_values = [file, "string", 42]
segment = variable_factory.build_segment(mixed_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == mixed_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_all_none_values():
"""Test building ArrayAnySegment from list with all None values."""
none_values = [None, None, None]
segment = variable_factory.build_segment(none_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == none_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_file_properties():
"""Test ArrayFileSegment properties and methods."""
file1 = File(
id="test_file_id_1",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file1.png",
filename="test-file1",
extension=".png",
mime_type="image/png",
size=1000,
)
file2 = File(
id="test_file_id_2",
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file2.txt",
filename="test-file2",
extension=".txt",
mime_type="text/plain",
size=500,
)
segment = variable_factory.build_segment([file1, file2])
# Test properties
assert segment.text == "" # ArrayFileSegment text property returns empty string
assert segment.log == "" # ArrayFileSegment log property returns empty string
assert segment.markdown == f"{file1.markdown}\n{file2.markdown}"
assert segment.to_object() == [file1, file2]
def test_build_segment_array_any_properties():
"""Test ArrayAnySegment properties and methods."""
mixed_values = ["string", 42, None]
segment = variable_factory.build_segment(mixed_values)
# Test properties
assert segment.text == str(mixed_values)
assert segment.log == str(mixed_values)
assert segment.markdown == "string\n42\nNone"
assert segment.to_object() == mixed_values
def test_build_segment_edge_cases():
"""Test edge cases for build_segment function."""
# Test with complex nested structures
complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"]
segment = variable_factory.build_segment(complex_structure)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == complex_structure
# Test with single None in list
single_none = [None]
segment = variable_factory.build_segment(single_none)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == single_none
def test_build_segment_file_array_with_different_file_types():
"""Test ArrayFileSegment with different file types."""
image_file = File(
id="image_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/image.png",
filename="image",
extension=".png",
mime_type="image/png",
size=1000,
)
video_file = File(
id="video_id",
tenant_id="test_tenant_id",
type=FileType.VIDEO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="video_relation_id",
filename="video",
extension=".mp4",
mime_type="video/mp4",
size=5000,
)
audio_file = File(
id="audio_id",
tenant_id="test_tenant_id",
type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="audio_relation_id",
filename="audio",
extension=".mp3",
mime_type="audio/mpeg",
size=3000,
)
segment = variable_factory.build_segment([image_file, video_file, audio_file])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 3
assert segment.value[0].type == FileType.IMAGE
assert segment.value[1].type == FileType.VIDEO
assert segment.value[2].type == FileType.AUDIO
@st.composite
def _generate_file(draw) -> File:
file_id = draw(st.text(min_size=1, max_size=10))
tenant_id = draw(st.text(min_size=1, max_size=10))
file_type, mime_type, extension = draw(
st.sampled_from(
[
(FileType.IMAGE, "image/png", ".png"),
(FileType.VIDEO, "video/mp4", ".mp4"),
(FileType.DOCUMENT, "text/plain", ".txt"),
(FileType.AUDIO, "audio/mpeg", ".mp3"),
]
)
)
filename = "test-file"
size = draw(st.integers(min_value=0, max_value=1024 * 1024))
transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
if transfer_method == FileTransferMethod.REMOTE_URL:
url = "https://test.example.com/test-file"
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
remote_url=url,
related_id=None,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
else:
relation_id = draw(st.uuids(version=4))
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
related_id=str(relation_id),
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
return file
def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
return st.one_of(
st.none(),
st.integers(),
st.floats(),
st.text(),
_generate_file(),
)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
if isinstance(value, float) and math.isnan(value):
assert math.isnan(seg.value)
else:
assert seg.value == value
@given(st.lists(_scalar_value()))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)
assert seg.value == values
def test_build_segment_type_for_scalar():
@dataclass(frozen=True)
class TestCase:
value: int | float | str | File
expected_type: SegmentType
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
cases = [
TestCase(0, SegmentType.NUMBER),
TestCase(0.0, SegmentType.NUMBER),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
for idx, c in enumerate(cases, 1):
segment = variable_factory.build_segment(c.value)
assert segment.value_type == c.expected_type, f"test case {idx} failed."
class TestBuildSegmentWithType:
"""Test cases for build_segment_with_type function."""
def test_string_type(self):
"""Test building a string segment with correct type."""
result = build_segment_with_type(SegmentType.STRING, "hello")
assert isinstance(result, StringSegment)
assert result.value == "hello"
assert result.value_type == SegmentType.STRING
def test_number_type_integer(self):
"""Test building a number segment with integer value."""
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
assert result.value_type == SegmentType.NUMBER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
assert result.value_type == SegmentType.NUMBER
def test_object_type(self):
"""Test building an object segment with correct type."""
test_obj = {"key": "value", "nested": {"inner": 123}}
result = build_segment_with_type(SegmentType.OBJECT, test_obj)
assert isinstance(result, ObjectSegment)
assert result.value == test_obj
assert result.value_type == SegmentType.OBJECT
def test_file_type(self):
"""Test building a file segment with correct type."""
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
storage_key="test_storage_key",
)
result = build_segment_with_type(SegmentType.FILE, test_file)
assert isinstance(result, FileSegment)
assert result.value == test_file
assert result.value_type == SegmentType.FILE
def test_none_type(self):
"""Test building a none segment with None value."""
result = build_segment_with_type(SegmentType.NONE, None)
assert isinstance(result, NoneSegment)
assert result.value is None
assert result.value_type == SegmentType.NONE
def test_empty_array_string(self):
"""Test building an empty array[string] segment."""
result = build_segment_with_type(SegmentType.ARRAY_STRING, [])
assert isinstance(result, ArrayStringSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_STRING
def test_empty_array_number(self):
"""Test building an empty array[number] segment."""
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [])
assert isinstance(result, ArrayNumberSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_NUMBER
def test_empty_array_object(self):
"""Test building an empty array[object] segment."""
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [])
assert isinstance(result, ArrayObjectSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_OBJECT
def test_empty_array_file(self):
"""Test building an empty array[file] segment."""
result = build_segment_with_type(SegmentType.ARRAY_FILE, [])
assert isinstance(result, ArrayFileSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_FILE
def test_empty_array_any(self):
"""Test building an empty array[any] segment."""
result = build_segment_with_type(SegmentType.ARRAY_ANY, [])
assert isinstance(result, ArrayAnySegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_ANY
def test_array_with_values(self):
"""Test building array segments with actual values."""
# Array of strings
result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"])
assert isinstance(result, ArrayStringSegment)
assert result.value == ["hello", "world"]
assert result.value_type == SegmentType.ARRAY_STRING
# Array of numbers
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14])
assert isinstance(result, ArrayNumberSegment)
assert result.value == [1, 2, 3.14]
assert result.value_type == SegmentType.ARRAY_NUMBER
# Array of objects
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}])
assert isinstance(result, ArrayObjectSegment)
assert result.value == [{"a": 1}, {"b": 2}]
assert result.value_type == SegmentType.ARRAY_OBJECT
def test_type_mismatch_string_to_number(self):
"""Test type mismatch when expecting number but getting string."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.NUMBER, "not_a_number")
assert "Type mismatch" in str(exc_info.value)
assert "expected number" in str(exc_info.value)
assert "str" in str(exc_info.value)
def test_type_mismatch_number_to_string(self):
"""Test type mismatch when expecting string but getting number."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, 123)
assert "Type mismatch" in str(exc_info.value)
assert "expected string" in str(exc_info.value)
assert "int" in str(exc_info.value)
def test_type_mismatch_none_to_string(self):
"""Test type mismatch when expecting string but getting None."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
assert "Expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
assert "Expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"})
assert "Type mismatch" in str(exc_info.value)
assert "expected array[string]" in str(exc_info.value)
def test_compatible_number_types(self):
"""Test that int and float are both compatible with NUMBER type."""
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
assert result_int.value_type == SegmentType.NUMBER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
assert result_float.value_type == SegmentType.NUMBER
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
(SegmentType.NUMBER, 42, IntegerSegment),
(SegmentType.NUMBER, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
(SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment),
(SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment),
(SegmentType.ARRAY_ANY, [], ArrayAnySegment),
],
)
def test_parametrized_valid_types(self, segment_type, value, expected_class):
"""Parametrized test for valid type combinations."""
result = build_segment_with_type(segment_type, value)
assert isinstance(result, expected_class)
assert result.value == value
assert result.value_type == segment_type
@pytest.mark.parametrize(
("segment_type", "value"),
[
(SegmentType.STRING, 123),
(SegmentType.NUMBER, "not_a_number"),
(SegmentType.OBJECT, "not_an_object"),
(SegmentType.ARRAY_STRING, "not_an_array"),
(SegmentType.STRING, None),
(SegmentType.NUMBER, None),
],
)
def test_parametrized_type_mismatches(self, segment_type, value):
"""Parametrized test for type mismatches that should raise TypeMismatchError."""
with pytest.raises(TypeMismatchError):
build_segment_with_type(segment_type, value)
# Test cases for ValueError scenarios in build_segment function
class TestBuildSegmentValueErrors:
"""Test cases for ValueError scenarios in the build_segment function."""
@dataclass(frozen=True)
class ValueErrorTestCase:
"""Test case data for ValueError scenarios."""
name: str
description: str
test_value: Any
def _get_test_cases(self):
"""Get all test cases for ValueError scenarios."""
# Define inline classes for complex test cases
class CustomType:
pass
def unsupported_function():
return "test"
def gen():
yield 1
yield 2
return [
self.ValueErrorTestCase(
name="unsupported_custom_type",
description="custom class that doesn't match any supported type",
test_value=CustomType(),
),
self.ValueErrorTestCase(
name="unsupported_set_type",
description="set (unsupported collection type)",
test_value={1, 2, 3},
),
self.ValueErrorTestCase(
name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3)
),
self.ValueErrorTestCase(
name="unsupported_bytes_type",
description="bytes (unsupported type)",
test_value=b"hello world",
),
self.ValueErrorTestCase(
name="unsupported_function_type",
description="function (unsupported type)",
test_value=unsupported_function,
),
self.ValueErrorTestCase(
name="unsupported_module_type", description="module (unsupported type)", test_value=math
),
self.ValueErrorTestCase(
name="array_with_unsupported_element_types",
description="array with unsupported element types",
test_value=[CustomType()],
),
self.ValueErrorTestCase(
name="mixed_array_with_unsupported_types",
description="array with mix of supported and unsupported types",
test_value=["valid_string", 42, CustomType()],
),
self.ValueErrorTestCase(
name="nested_unsupported_types",
description="nested structures containing unsupported types",
test_value=[{"valid": "data"}, CustomType()],
),
self.ValueErrorTestCase(
name="complex_number_type",
description="complex number (unsupported type)",
test_value=3 + 4j,
),
self.ValueErrorTestCase(
name="range_type", description="range object (unsupported type)", test_value=range(10)
),
self.ValueErrorTestCase(
name="generator_type",
description="generator (unsupported type)",
test_value=gen(),
),
self.ValueErrorTestCase(
name="exception_message_contains_value",
description="set to verify error message contains the actual unsupported value",
test_value={1, 2, 3},
),
self.ValueErrorTestCase(
name="array_with_mixed_unsupported_segment_types",
description="array processing with unsupported segment types in match",
test_value=[CustomType()],
),
self.ValueErrorTestCase(
name="frozenset_type",
description="frozenset (unsupported type)",
test_value=frozenset([1, 2, 3]),
),
self.ValueErrorTestCase(
name="memoryview_type",
description="memoryview (unsupported type)",
test_value=memoryview(b"hello"),
),
self.ValueErrorTestCase(
name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2)
),
self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type),
self.ValueErrorTestCase(
name="generic_object", description="generic object (unsupported type)", test_value=object()
),
]
def test_build_segment_unsupported_types(self):
"""Table-driven test for all ValueError scenarios in build_segment function."""
test_cases = self._get_test_cases()
for index, test_case in enumerate(test_cases, 1):
# Use test value directly
test_value = test_case.test_value
with pytest.raises(ValueError) as exc_info: # noqa: PT012
segment = variable_factory.build_segment(test_value)
pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but not, result={segment}")
error_message = str(exc_info.value)
assert "not supported value" in error_message, (
f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, "
f"but got: {error_message}"
)
def test_build_segment_boolean_type_note(self):
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
# Boolean values in Python are subclasses of int, so they get processed as integers
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
true_segment = variable_factory.build_segment(True)
false_segment = variable_factory.build_segment(False)
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.NUMBER
assert false_segment.value_type == SegmentType.NUMBER

View File

@@ -0,0 +1,20 @@
import datetime
from libs.datetime_utils import naive_utc_now
def test_naive_utc_now(monkeypatch):
tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
return tz_aware_utc_now.astimezone(tz)
monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
naive_datetime = naive_utc_now()
assert naive_datetime.tzinfo is None
assert naive_datetime.date() == tz_aware_utc_now.date()
naive_time = naive_datetime.time()
utc_time = tz_aware_utc_now.time()
assert naive_time == utc_time

View File

View File

@@ -1,10 +1,15 @@
import dataclasses
import json
from unittest import mock
from uuid import uuid4
from constants import HIDDEN_VALUE
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecutionModel
from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
def test_environment_variables():
@@ -163,3 +168,147 @@ class TestWorkflowNodeExecution:
original = {"a": 1, "b": ["2"]}
node_exec.execution_metadata = json.dumps(original)
assert node_exec.execution_metadata_dict == original
class TestIsSystemVariableEditable:
def test_is_system_variable(self):
cases = [
("query", True),
("files", True),
("dialogue_count", False),
("conversation_id", False),
("user_id", False),
("app_id", False),
("workflow_id", False),
("workflow_run_id", False),
]
for name, editable in cases:
assert editable == is_system_variable_editable(name)
assert is_system_variable_editable("invalid_or_new_system_variable") == False
class TestWorkflowDraftVariableGetValue:
def test_get_value_by_case(self):
@dataclasses.dataclass
class TestCase:
name: str
value: Segment
tenant_id = "test_tenant_id"
test_file = File(
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/example.jpg",
filename="example.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=100,
)
cases: list[TestCase] = [
TestCase(
name="number/int",
value=build_segment(1),
),
TestCase(
name="number/float",
value=build_segment(1.0),
),
TestCase(
name="string",
value=build_segment("a"),
),
TestCase(
name="object",
value=build_segment({}),
),
TestCase(
name="file",
value=build_segment(test_file),
),
TestCase(
name="array[any]",
value=build_segment([1, "a"]),
),
TestCase(
name="array[string]",
value=build_segment(["a", "b"]),
),
TestCase(
name="array[number]/int",
value=build_segment([1, 2]),
),
TestCase(
name="array[number]/float",
value=build_segment([1.0, 2.0]),
),
TestCase(
name="array[number]/mixed",
value=build_segment([1, 2.0]),
),
TestCase(
name="array[object]",
value=build_segment([{}, {"a": 1}]),
),
TestCase(
name="none",
value=build_segment(None),
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"test case {c.name} failed, index={idx}"
draft_var = WorkflowDraftVariable()
draft_var.set_value(c.value)
assert c.value == draft_var.get_value(), fail_msg
def test_file_variable_preserves_all_fields(self):
"""Test that File type variables preserve all fields during encoding/decoding."""
tenant_id = "test_tenant_id"
# Create a File with specific field values
test_file = File(
id="test_file_id",
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345, # Specific size to test preservation
storage_key="test_storage_key",
)
# Create a FileSegment and WorkflowDraftVariable
file_segment = build_segment(test_file)
draft_var = WorkflowDraftVariable()
draft_var.set_value(file_segment)
# Retrieve the value and verify all fields are preserved
retrieved_segment = draft_var.get_value()
retrieved_file = retrieved_segment.value
# Verify all important fields are preserved
assert retrieved_file.id == test_file.id
assert retrieved_file.tenant_id == test_file.tenant_id
assert retrieved_file.type == test_file.type
assert retrieved_file.transfer_method == test_file.transfer_method
assert retrieved_file.remote_url == test_file.remote_url
assert retrieved_file.filename == test_file.filename
assert retrieved_file.extension == test_file.extension
assert retrieved_file.mime_type == test_file.mime_type
assert retrieved_file.size == test_file.size # This was the main issue being fixed
# Note: storage_key is not serialized in model_dump() so it won't be preserved
# Verify the segments have the same type and the important fields match
assert file_segment.value_type == retrieved_segment.value_type
def test_get_and_set_value(self):
draft_var = WorkflowDraftVariable()
int_var = IntegerSegment(value=1)
draft_var.set_value(int_var)
value = draft_var.get_value()
assert value == int_var

View File

@@ -0,0 +1,222 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
WorkflowDraftVariableService,
)
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = mock.MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id="test_node_id",
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
def test__normalize_variable_for_start_node(self):
@dataclasses.dataclass(frozen=True)
class TestCase:
name: str
input_node_id: str
input_name: str
expected_node_id: str
expected_name: str
_NODE_ID = "1747228642872"
cases = [
TestCase(
name="name with `sys.` prefix should return the system node_id",
input_node_id=_NODE_ID,
input_name="sys.workflow_id",
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
expected_name="workflow_id",
),
TestCase(
name="name without `sys.` prefix should return the original input node_id",
input_node_id=_NODE_ID,
input_name="start_input",
expected_node_id=_NODE_ID,
expected_name="start_input",
),
TestCase(
name="dummy_variable should return the original input node_id",
input_node_id=_NODE_ID,
input_name="__dummy__",
expected_node_id=_NODE_ID,
expected_name="__dummy__",
),
]
mock_session = mock.MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id=_NODE_ID,
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
mock_variable.id = "var-id"
mock_variable.name = "test_var"
# Mock the _reset_conv_var method
expected_result = Mock(spec=WorkflowDraftVariable)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
result = service.reset_variable(mock_workflow, mock_variable)
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with no execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = None
mock_variable.id = "var-id"
mock_variable.name = "test_var"
result = service._reset_node_var(mock_workflow, mock_variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var(mock_workflow, mock_variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
mock_variable.node_id = "node-id"
mock_variable.value_type = SegmentType.STRING
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock workflow methods
mock_node_config = {"type": "test_node"}
mock_workflow.get_node_config_by_id.return_value = mock_node_config
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
result = service._reset_node_var(mock_workflow, mock_variable)
# Verify variable.set_value was called with the correct value
mock_variable.set_value.assert_called_once()
# Verify last_edited_at was reset
assert mock_variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Should return the updated variable
assert result == mock_variable
def test_reset_system_variable_raises_error(self):
"""Test that resetting a system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test
mock_variable.id = "var-id"
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(mock_workflow, mock_variable)
assert "cannot reset system variable" in str(exc_info.value)
assert "variable_id=var-id" in str(exc_info.value)