feat: Persist Variables for Enhanced Debugging Workflow (#20699)

This pull request introduces variable persistence to improve the debugging experience during workflow editing. The system now automatically retains the output variables of previously executed nodes, and those persisted variables can be reused when debugging subsequent nodes, eliminating repetitive manual input.

By streamlining this part of the workflow, the feature reduces user error and debugging effort, offering a smoother and more efficient editing experience.

Key highlights of this change:

- Automatic persistence of output variables for executed nodes.
- Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`).
- Enhanced debugging experience with reduced friction.

Closes #19735.
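
For illustration, here is a minimal sketch of how persisted draft variables can be pulled back in before a node is debugged, using the `DraftVarLoader` exercised by the tests in this PR. The helper `resolve_debug_inputs` and its signature are hypothetical and not part of this diff.

```python
# Hypothetical helper: load persisted outputs of upstream nodes so the user
# does not have to re-enter them when debugging a downstream node.
from models import db
from services.workflow_draft_variable_service import DraftVarLoader


def resolve_debug_inputs(app_id: str, tenant_id: str, selectors: list[list[str]]):
    """Each selector is a [node_id, variable_name] pair, e.g. ["test_node_1", "str_var"]."""
    loader = DraftVarLoader(engine=db.engine, app_id=app_id, tenant_id=tenant_id)
    return loader.load_variables(selectors)
```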
Authored by QuantumGhost on 2025-06-24 09:05:29 +08:00, committed by GitHub.
Parent 3113350e51, commit 10b738a296
106 changed files with 6025 additions and 718 deletions

View File

@@ -1,107 +1,217 @@
# OpenAI API Key
OPENAI_API_KEY=
FLASK_APP=app.py
FLASK_DEBUG=0
SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI'
# Azure OpenAI API Base Endpoint & API Key
AZURE_OPENAI_API_BASE=
AZURE_OPENAI_API_KEY=
CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000
# Anthropic API Key
ANTHROPIC_API_KEY=
# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001
# Replicate API Key
REPLICATE_API_KEY=
# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL=
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
# Files URL
FILES_URL=http://127.0.0.1:5001
# Minimax Credentials
MINIMAX_API_KEY=
MINIMAX_GROUP_ID=
# Time in seconds after which a file access signature expires and is rejected
FILES_ACCESS_TIMEOUT=300
# Spark Credentials
SPARK_APP_ID=
SPARK_API_KEY=
SPARK_API_SECRET=
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Tongyi Credentials
TONGYI_DASHSCOPE_API_KEY=
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# Wenxin Credentials
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
# ChatGLM Credentials
CHATGLM_API_BASE=
# Storage configuration
# used for storing uploaded files, private keys, etc.
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal
# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_GENERATION_MODEL_UID=
XINFERENCE_CHAT_MODEL_UID=
XINFERENCE_EMBEDDINGS_MODEL_UID=
XINFERENCE_RERANK_MODEL_UID=
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage
# OpenLLM Credentials
OPENLLM_SERVER_URL=
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# LocalAI Credentials
LOCALAI_SERVER_URL=
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
VECTOR_STORE=weaviate
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# Cohere Credentials
COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Ollama Credentials
OLLAMA_BASE_URL=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=4096
CODE_GENERATION_MAX_TOKENS=1024
# Together API Key
TOGETHER_API_KEY=
# Mail configuration, support: resend, smtp
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@example.com>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.example.com
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Mock Switch
MOCK_SWITCH=false
# Sentry configuration
SENTRY_DSN=
# DEBUG
DEBUG=false
SQLALCHEMY_ECHO=false
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret
ETL_TYPE=dify
UNSTRUCTURED_API_URL=
UNSTRUCTURED_API_KEY=
SCARF_NO_ANALYTICS=false
#ssrf
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
SSRF_DEFAULT_MAX_RETRIES=3
SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5
BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database
# Workflow file upload limit
WORKFLOW_FILE_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
CODE_EXECUTION_API_KEY=
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
# Volcengine MaaS Credentials
VOLC_API_KEY=
VOLC_SECRET_KEY=
VOLC_MODEL_ENDPOINT_ID=
VOLC_EMBEDDING_ENDPOINT_ID=
# API Tool configuration
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# 360 AI Credentials
ZHINAO_API_KEY=
# HTTP Node configuration
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
# Log file path
LOG_FILE=
# Log file max size, the unit is MB
LOG_FILE_MAX_SIZE=20
# Log file max backup count
LOG_FILE_BACKUP_COUNT=5
# Log dateformat
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Workflow runtime configuration
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
# Plugin configuration
PLUGIN_DAEMON_KEY=
PLUGIN_DAEMON_URL=
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_API_URL=
# VESSL AI Credentials
VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=
VESSL_AI_ENDPOINT_URL=
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# GPUStack Credentials
GPUSTACK_SERVER_URL=
GPUSTACK_API_KEY=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Gitee AI Credentials
GITEE_AI_API_KEY=
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=
CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of tasks submitted to the thread pool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400
HTTP_PROXY='http://127.0.0.1:1092'
HTTPS_PROXY='http://127.0.0.1:1092'
NO_PROXY='localhost,127.0.0.1'
LOG_LEVEL=INFO

View File

@@ -1,19 +1,91 @@
import os
import pathlib
import random
import secrets
from collections.abc import Generator
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
from app_factory import create_app
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
from services.account_service import AccountService, RegisterService
# Loading the .env file if it exists
def _load_env() -> None:
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
if os.path.exists(dotenv_path):
current_file_path = pathlib.Path(__file__).absolute()
# Items later in the list have higher precedence.
files_to_load = [".env", "vdb.env"]
env_file_paths = [current_file_path.parent / i for i in files_to_load]
for path in env_file_paths:
if not path.exists():
continue
from dotenv import load_dotenv
load_dotenv(dotenv_path)
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
load_dotenv(str(path), override=True)
_load_env()
_CACHED_APP = create_app()
@pytest.fixture
def flask_app() -> Flask:
return _CACHED_APP
@pytest.fixture(scope="session")
def setup_account(request) -> Generator[Account, None, None]:
"""`dify_setup` completes the setup process for the Dify application.
It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.
Most tests in the `controllers` package require that Dify has been successfully set up.
"""
with _CACHED_APP.test_request_context():
rand_suffix = random.randint(int(1e6), int(1e7)) # noqa
name = f"test-user-{rand_suffix}"
email = f"{name}@example.com"
RegisterService.setup(
email=email,
name=name,
password=secrets.token_hex(16),
ip_address="localhost",
)
with _CACHED_APP.test_request_context():
with Session(bind=db.engine, expire_on_commit=False) as session:
account = session.query(Account).filter_by(email=email).one()
yield account
with _CACHED_APP.test_request_context():
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.commit()
@pytest.fixture
def flask_req_ctx():
with _CACHED_APP.test_request_context():
yield
@pytest.fixture
def auth_header(setup_account) -> dict[str, str]:
token = AccountService.get_account_jwt_token(setup_account)
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def test_client() -> Generator[FlaskClient, None, None]:
with _CACHED_APP.test_client() as client:
yield client
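
For reference, a hypothetical test combining these fixtures might look like the sketch below; the endpoint path and assertion are illustrative only and do not appear in this diff.

```python
def test_console_api_with_auth(test_client, auth_header):
    # `test_client` is bound to the cached Flask app; `auth_header` carries a JWT
    # for the account created by the session-scoped `setup_account` fixture.
    response = test_client.get("/console/api/apps", headers=auth_header)  # illustrative path
    assert response.status_code == 200  # illustrative expectation
```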

View File

@@ -1,25 +0,0 @@
import pytest
from app_factory import create_app
from configs import dify_config
mock_user = type(
"MockUser",
(object,),
{
"is_authenticated": True,
"id": "123",
"is_editor": True,
"is_dataset_editor": True,
"status": "active",
"get_id": "123",
"current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b",
},
)
@pytest.fixture
def app():
app = create_app()
dify_config.LOGIN_DISABLED = True
return app

View File

@@ -0,0 +1,47 @@
import uuid
from unittest import mock
from controllers.console.app import workflow_draft_variable as draft_variable_api
from controllers.console.app import wraps
from factories.variable_factory import build_segment
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
return mock.create_autospec(WorkflowDraftVariableService)
class TestWorkflowDraftNodeVariableListApi:
def test_get(self, test_client, auth_header, monkeypatch):
srv_class = _get_mock_srv_class()
mock_app_model: App = App()
mock_app_model.id = str(uuid.uuid4())
test_node_id = "test_node_id"
mock_app_model.mode = AppMode.ADVANCED_CHAT
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
var1 = WorkflowDraftVariable.new_node_variable(
app_id="test_app_1",
node_id="test_node_1",
name="str_var",
value=build_segment("str_value"),
node_execution_id=str(uuid.uuid4()),
)
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
srv_class.return_value = srv_instance
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
headers=auth_header,
)
assert response.status_code == 200
response_dict = response.json
assert isinstance(response_dict, dict)
assert "items" in response_dict
assert len(response_dict["items"]) == 1

View File

@@ -1,9 +0,0 @@
from unittest.mock import patch
from app_fixture import mock_user # type: ignore
def test_post_requires_login(app):
with app.test_client() as client, patch("flask_login.utils._get_user", mock_user):
response = client.get("/console/api/data-source/integrates")
assert response.status_code == 200

View File

@@ -0,0 +1,371 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@pytest.mark.usefixtures("flask_req_ctx")
class TestStorageKeyLoader(unittest.TestCase):
"""
Integration tests for StorageKeyLoader class.
Tests the batched loading of storage keys from the database for files
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
"""
def setUp(self):
"""Set up test data before each test method."""
self.session = db.session()
self.tenant_id = str(uuid4())
self.user_id = str(uuid4())
self.conversation_id = str(uuid4())
# Create test data that will be cleaned up after each test
self.test_upload_files = []
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
def tearDown(self):
"""Clean up test data after each test method."""
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if storage_key is None:
storage_key = f"test_storage_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=storage_key,
name="test_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file.id = file_id
self.session.add(upload_file)
self.session.flush()
self.test_upload_files.append(upload_file)
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if file_key is None:
file_key = f"test_file_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
file_related_id = None
remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)
def test_load_storage_keys_local_file(self):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_remote_url(self):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_tool_file(self):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == tool_file.file_key
def test_load_storage_keys_mixed_methods(self):
"""Test batch loading with mixed transfer methods."""
# Create test data for different transfer methods
upload_file1 = self._create_upload_file()
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
files = [file1, file2, file3]
# Load storage keys
self.loader.load_storage_keys(files)
# Verify all storage keys were loaded correctly
assert file1._storage_key == upload_file1.key
assert file2._storage_key == upload_file2.key
assert file3._storage_key == tool_file.file_key
def test_load_storage_keys_empty_list(self):
"""Test with empty file list."""
# Should not raise any exceptions
self.loader.load_storage_keys([])
def test_load_storage_keys_tenant_mismatch(self):
"""Test tenant_id validation."""
# Create file with different tenant_id
upload_file = self._create_upload_file()
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)
# Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_missing_file_id(self):
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None
# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert str(context.value) == "file id should not be None."
def test_load_storage_keys_nonexistent_upload_file_records(self):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_nonexistent_tool_file_records(self):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_batch_efficiency(self):
"""Test batched operations use efficient queries."""
# Create multiple files of different types
upload_files = [self._create_upload_file() for _ in range(3)]
tool_files = [self._create_tool_file() for _ in range(2)]
files = []
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
self.loader.load_storage_keys(files)
# Should make exactly 2 queries (one for upload_files, one for tool_files)
assert mock_scalars.call_count == 2
# Verify all storage keys were loaded correctly
for i, file in enumerate(files[:3]):
assert file._storage_key == upload_files[i].key
for i, file in enumerate(files[3:]):
assert file._storage_key == tool_files[i].file_key
def test_load_storage_keys_tenant_isolation(self):
"""Test that tenant isolation works correctly."""
# Create files for different tenants
other_tenant_id = str(uuid4())
# Create upload file for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
key="other_tenant_key",
name="other_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file_other.id = str(uuid4())
self.session.add(upload_file_other)
self.session.flush()
# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_other])
assert "invalid file, expected tenant_id" in str(context.value)
# Current tenant's file should still work
self.loader.load_storage_keys([file_current])
assert file_current._storage_key == upload_file_current.key
def test_load_storage_keys_mixed_tenant_batch(self):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_current, file_other])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_duplicate_file_ids(self):
"""Test handling of duplicate file IDs in the batch."""
# Create upload file
upload_file = self._create_upload_file()
# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])
# Both files should have the same storage key
assert file1._storage_key == upload_file.key
assert file2._storage_key == upload_file.key
def test_load_storage_keys_session_isolation(self):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create loader with different session (same underlying connection)
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])
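
To make the batching expectation above concrete, here is a hedged sketch of what a two-query loader can look like. It is not the implementation in `factories/file_factory.py` and it omits the tenant and missing-record validation that the tests assert (the real loader raises `ValueError` and assigns the key onto each `File`).

```python
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.file import File, FileTransferMethod
from models import ToolFile, UploadFile


def load_storage_keys_sketch(session: Session, files: Sequence[File]) -> dict[str, str]:
    """Return a {related_id: storage_key} mapping using one query per backing table."""
    upload_ids = [f.related_id for f in files
                  if f.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL)]
    tool_ids = [f.related_id for f in files if f.transfer_method == FileTransferMethod.TOOL_FILE]

    keys: dict[str, str] = {}
    if upload_ids:
        # One query for all UploadFile-backed files (LOCAL_FILE and REMOTE_URL).
        for row in session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_ids))):
            keys[str(row.id)] = row.key
    if tool_ids:
        # One query for all ToolFile-backed files.
        for row in session.scalars(select(ToolFile).where(ToolFile.id.in_(tool_ids))):
            keys[str(row.id)] = row.file_key
    return keys
```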

View File

@@ -0,0 +1,501 @@
import json
import unittest
import uuid
import pytest
from sqlalchemy.orm import Session
from core.variables.variables import StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from factories.variable_factory import build_segment
from libs import datetime_utils
from models import db
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_node1_id = "test_node_1"
_node2_id = "test_node_2"
_node_exec_id = str(uuid.uuid4())
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._session: Session = db.session()
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
visible=False,
node_execution_id=self._node_exec_id,
),
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
),
]
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
)
_variables = list(node2_vars)
_variables.extend(
[
node1_var,
sys_var,
conv_var,
]
)
db.session.add_all(_variables)
db.session.flush()
self._variable_ids = [v.id for v in _variables]
self._node1_str_var_id = node1_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id
self._node2_var_ids = [v.id for v in node2_vars]
def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)
def tearDown(self):
self._session.rollback()
def test_list_variables(self):
srv = self._get_test_srv()
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
assert var_list.total == 5
assert len(var_list.variables) == 2
page1_var_ids = {v.id for v in var_list.variables}
assert page1_var_ids.issubset(self._variable_ids)
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
assert var_list_2.total is None
assert len(var_list_2.variables) == 2
page2_var_ids = {v.id for v in var_list_2.variables}
assert page2_var_ids.isdisjoint(page1_var_ids)
assert page2_var_ids.issubset(self._variable_ids)
def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
assert node_var is not None
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
assert node_var.get_value() == build_segment("str_value")
def test_get_system_variable(self):
srv = self._get_test_srv()
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
assert sys_var is not None
assert sys_var.id == self._sys_var_id
assert sys_var.name == "sys_var"
assert sys_var.get_value() == build_segment("sys_value")
def test_get_conversation_variable(self):
srv = self._get_test_srv()
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
assert conv_var is not None
assert conv_var.id == self._conv_var_id
assert conv_var.name == "conv_var"
assert conv_var.get_value() == build_segment("conv_value")
def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id,
)
.count()
)
assert node2_var_count == 0
def test_delete_variable(self):
srv = self._get_test_srv()
node_1_var = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
)
srv.delete_variable(node_1_var)
exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
)
assert exists is False
def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
assert len(node_vars.variables) == 2
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
def test_get_draft_variables_by_selectors(self):
srv = self._get_test_srv()
selectors = [
[self._node1_id, "str_var"],
[self._node2_id, "str_var"],
[self._node2_id, "int_var"],
]
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
assert len(variables) == 3
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
@pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase):
_test_app_id: str
_test_tenant_id: str
_node1_id = "test_loader_node_1"
_node_exec_id = str(uuid.uuid4())
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
node_execution_id=self._node_exec_id,
)
_variables = [
node_var,
sys_var,
conv_var,
]
with Session(bind=db.engine, expire_on_commit=False) as session:
session.add_all(_variables)
session.flush()
session.commit()
self._variable_ids = [v.id for v in _variables]
self._node_var_id = node_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.commit()
def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables([])
assert len(variables) == 0
def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
[self._node1_id, "str_var"],
]
)
assert len(variables) == 3
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
assert conv_var.id == self._conv_var_id
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
assert sys_var.id == self._sys_var_id
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
assert node1_var.id == self._node_var_id
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
"""Integration tests for reset_variable functionality using real database"""
_test_app_id: str
_test_tenant_id: str
_test_workflow_id: str
_session: Session
_node_id = "test_reset_node"
_node_exec_id: str
_workflow_node_exec_id: str
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
self._test_workflow_id = str(uuid.uuid4())
self._node_exec_id = str(uuid.uuid4())
self._workflow_node_exec_id = str(uuid.uuid4())
self._session: Session = db.session()
# Create a workflow node execution record with outputs
# Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable
self._workflow_node_execution = WorkflowNodeExecutionModel(
id=self._node_exec_id, # This should match the node_execution_id in the variable
tenant_id=self._test_tenant_id,
app_id=self._test_app_id,
workflow_id=self._test_workflow_id,
triggered_from="workflow-run",
workflow_run_id=str(uuid.uuid4()),
index=1,
node_execution_id=self._node_exec_id,
node_id=self._node_id,
node_type=NodeType.LLM.value,
title="Test Node",
inputs='{"input": "test input"}',
process_data='{"test_var": "process_value", "other_var": "other_process"}',
outputs='{"test_var": "output_value", "other_var": "other_output"}',
status="succeeded",
elapsed_time=1.5,
created_by_role="account",
created_by=str(uuid.uuid4()),
)
# Create conversation variables for the workflow
self._conv_variables = [
StringVariable(
id=str(uuid.uuid4()),
name="conv_var_1",
description="Test conversation variable 1",
value="default_value_1",
),
StringVariable(
id=str(uuid.uuid4()),
name="conv_var_2",
description="Test conversation variable 2",
value="default_value_2",
),
]
# Create test variables
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="test_var",
value=build_segment("old_value"),
node_execution_id=self._node_exec_id,
)
self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now()
self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="no_exec_var",
value=build_segment("some_value"),
node_execution_id="temp_exec_id",
)
# Manually set node_execution_id to None after creation
self._node_var_without_exec.node_execution_id = None
self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node_id,
name="missing_exec_var",
value=build_segment("some_value"),
node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database
)
self._conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var_1",
value=build_segment("old_conv_value"),
)
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
# Add all to database
db.session.add_all(
[
self._workflow_node_execution,
self._node_var_with_exec,
self._node_var_without_exec,
self._node_var_missing_exec,
self._conv_var,
]
)
db.session.flush()
# Store IDs for assertions
self._node_var_with_exec_id = self._node_var_with_exec.id
self._node_var_without_exec_id = self._node_var_without_exec.id
self._node_var_missing_exec_id = self._node_var_missing_exec.id
self._conv_var_id = self._conv_var.id
def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)
def _create_mock_workflow(self) -> Workflow:
"""Create a real workflow with conversation variables and graph"""
conversation_vars = self._conv_variables
# Create a simple graph with the test node
graph = {
"nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}],
"edges": [],
}
workflow = Workflow.new(
tenant_id=str(uuid.uuid4()),
app_id=self._test_app_id,
type="workflow",
version="1.0",
graph=json.dumps(graph),
features="{}",
created_by=str(uuid.uuid4()),
environment_variables=[],
conversation_variables=conversation_vars,
)
return workflow
def tearDown(self):
self._session.rollback()
def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()
# Get the variable before reset
variable = srv.get_variable(self._node_var_with_exec_id)
assert variable is not None
assert variable.get_value().value == "old_value"
assert variable.last_edited_at is not None
# Reset the variable
result = srv.reset_variable(mock_workflow, variable)
# Should return the updated variable
assert result is not None
assert result.id == self._node_var_with_exec_id
assert result.node_execution_id == self._workflow_node_execution.id
assert result.last_edited_at is None # Should be reset to None
# The returned variable should have the updated value from execution record
assert result.get_value().value == "output_value"
# Verify the variable was updated in database
updated_variable = srv.get_variable(self._node_var_with_exec_id)
assert updated_variable is not None
# The value should be updated from the execution record's outputs
assert updated_variable.get_value().value == "output_value"
assert updated_variable.last_edited_at is None
assert updated_variable.node_execution_id == self._workflow_node_execution.id
def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()
# Get the variable before reset
variable = srv.get_variable(self._node_var_without_exec_id)
assert variable is not None
# Reset the variable
result = srv.reset_variable(mock_workflow, variable)
# Should return None (variable deleted)
assert result is None
# Verify the variable was deleted
deleted_variable = srv.get_variable(self._node_var_without_exec_id)
assert deleted_variable is None
def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()
# Get the variable before reset
variable = srv.get_variable(self._node_var_missing_exec_id)
assert variable is not None
# Reset the variable
result = srv.reset_variable(mock_workflow, variable)
# Should return None (variable deleted)
assert result is None
# Verify the variable was deleted
deleted_variable = srv.get_variable(self._node_var_missing_exec_id)
assert deleted_variable is None
def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()
# Get the variable before reset
variable = srv.get_variable(self._conv_var_id)
assert variable is not None
assert variable.get_value().value == "old_conv_value"
assert variable.last_edited_at is not None
# Reset the variable
result = srv.reset_variable(mock_workflow, variable)
# Should return the updated variable
assert result is not None
assert result.id == self._conv_var_id
assert result.last_edited_at is None # Should be reset to None
# Verify the variable was updated with default value from workflow
updated_variable = srv.get_variable(self._conv_var_id)
assert updated_variable is not None
# The value should be updated from the workflow's conversation variable default
assert updated_variable.get_value().value == "default_value_1"
assert updated_variable.last_edited_at is None
def test_reset_system_variable_raises_error(self):
"""Test that resetting a system variable raises an error"""
srv = self._get_test_srv()
mock_workflow = self._create_mock_workflow()
# Create a system variable
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
db.session.add(sys_var)
db.session.flush()
# Attempt to reset the system variable
with pytest.raises(VariableResetError) as exc_info:
srv.reset_variable(mock_workflow, sys_var)
assert "cannot reset system variable" in str(exc_info.value)
assert sys_var.id in str(exc_info.value)
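
The reset semantics these tests pin down can be summarized in a hedged sketch. The helper below only mirrors the asserted behavior (delete when no execution record exists, restore node variables from the execution's `outputs`, restore conversation variables from the workflow defaults, refuse system variables); its setter calls are assumptions and this is not the actual `WorkflowDraftVariableService.reset_variable` implementation.

```python
import json

from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models.workflow import WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import VariableResetError


def reset_variable_sketch(session, workflow, variable):
    if variable.node_id == SYSTEM_VARIABLE_NODE_ID:
        raise VariableResetError(f"cannot reset system variable, variable id: {variable.id}")

    if variable.node_id == CONVERSATION_VARIABLE_NODE_ID:
        # Fall back to the default declared on the workflow's conversation variables.
        default = next(v for v in workflow.conversation_variables if v.name == variable.name)
        variable.set_value(build_segment(default.value))  # assumed setter name
        variable.last_edited_at = None
        return variable

    # Node variable: restore the value recorded by the node's last execution, if any.
    execution = (
        session.get(WorkflowNodeExecutionModel, variable.node_execution_id)
        if variable.node_execution_id
        else None
    )
    if execution is None:
        session.delete(variable)
        return None
    outputs = json.loads(execution.outputs or "{}")
    variable.set_value(build_segment(outputs[variable.name]))  # assumed setter name
    variable.last_edited_at = None
    return variable
```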

View File

@@ -8,8 +8,6 @@ from unittest.mock import MagicMock, patch
import pytest
from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
@@ -30,21 +28,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.fixture(scope="session")
def app():
# Set up storage configuration
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
# Ensure storage directory exists
os.makedirs("storage", exist_ok=True)
app = create_app()
dify_config.LOGIN_DISABLED = True
return app
def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
@@ -102,197 +85,195 @@ def init_llm_node(config: dict) -> LLMNode:
return node
def test_execute_llm(app):
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
def test_execute_llm(flask_req_ctx):
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
)
},
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
"""
Test execute LLM node with jinja2
"""
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json():

View File

@@ -0,0 +1,302 @@
import datetime
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple
from flask_restful import marshal
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList
_TEST_APP_ID = "test_app_id"
_TEST_NODE_EXEC_ID = str(uuid.uuid4())
class TestWorkflowDraftVariableFields:
def test_conversation_variable(self):
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
conv_var.id = str(uuid.uuid4())
conv_var.visible = True
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(conv_var.id),
"type": conv_var.get_variable_type().value,
"name": "conv_var",
"description": "",
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
"value_type": "number",
"edited": False,
"visible": True,
}
)
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = 1
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_create_sys_variable(self):
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=_TEST_APP_ID,
name="sys_var",
value=build_segment("a"),
editable=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
sys_var.visible = True
expected_without_value = OrderedDict(
{
"id": str(sys_var.id),
"type": sys_var.get_variable_type().value,
"name": "sys_var",
"description": "",
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
"value_type": "string",
"edited": True,
"visible": True,
}
)
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = "a"
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_node_variable(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
}
)
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
class TestWorkflowDraftVariableList:
def test_workflow_draft_variable_list(self):
class TestCase(NamedTuple):
name: str
var_list: WorkflowDraftVariableList
expected: dict
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="test_var",
value=build_segment("a"),
visible=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var_dict = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "test_var",
"description": "",
"selector": ["test_node", "test_var"],
"value_type": "string",
"edited": False,
"visible": True,
}
)
cases = [
TestCase(
name="empty variable list",
var_list=WorkflowDraftVariableList(variables=[]),
expected=OrderedDict(
{
"items": [],
"total": None,
}
),
),
TestCase(
name="empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[], total=10),
expected=OrderedDict(
{
"items": [],
"total": 10,
}
),
),
TestCase(
name="non-empty variable list",
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": None,
}
),
),
TestCase(
name="non-empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": 10,
}
),
),
]
for idx, case in enumerate(cases, 1):
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
f"Test case {idx} failed, {case.name=}"
)
def test_workflow_node_variables_fields():
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "conv_var"
assert item_dict["value"] == 1
def test_workflow_file_variable_with_signed_url():
"""Test that File type variables include signed URLs in API responses."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
# Verify the File fields are preserved
assert value["id"] == test_file.id
assert value["filename"] == test_file.filename
assert value["type"] == test_file.type.value
assert value["transfer_method"] == test_file.transfer_method.value
assert value["size"] == test_file.size
# Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method)
remote_url = value["remote_url"]
assert remote_url is not None
assert isinstance(remote_url, str)
# For LOCAL_FILE, the URL should contain signature parameters
assert "timestamp=" in remote_url
assert "nonce=" in remote_url
assert "sign=" in remote_url
def test_workflow_file_variable_remote_url():
"""Test that File type variables with REMOTE_URL transfer method return the remote URL."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
remote_url = value["remote_url"]
# For REMOTE_URL, the URL should be the original remote URL
assert remote_url == test_file.remote_url

View File

@@ -1,165 +0,0 @@
from uuid import uuid4
import pytest
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import ArrayAnySegment
from factories import variable_factory
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None
def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]

View File

@@ -0,0 +1,25 @@
from core.file import File, FileTransferMethod, FileType
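# Basic File model construction: constructor arguments should be exposed unchanged as attributes.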
def test_file():
file = File(
id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
assert file.tenant_id == "test-tenant-id"
assert file.type == FileType.IMAGE
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.related_id == "test-related-id"
assert file.filename == "image.png"
assert file.extension == ".png"
assert file.mime_type == "image/png"
assert file.size == 67

View File

@@ -3,6 +3,7 @@ import uuid
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -197,7 +198,7 @@ def test_run():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 20
@@ -413,7 +414,7 @@ def test_run_parallel():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32
@@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode():
parallel_arr.append(item)
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32
for item in sequential_result:
@@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 64
@@ -846,7 +847,7 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": [None, None]}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}
assert count == 14
# execute remove abnormal output
@@ -857,5 +858,5 @@ def test_iteration_run_error_handle():
count += 1
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs == {"output": []}
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])}
assert count == 14

View File

@@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P
from core.file import File, FileTransferMethod
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment
from core.variables.variables import StringVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"text/plain",
b"Hello, world!",
["Hello, world!"],
FileTransferMethod.LOCAL_FILE,
".txt",
),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
@@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
(
"text/plain",
b"Remote content",
["Remote content"],
FileTransferMethod.REMOTE_URL,
None,
),
],
)
def test_run_extract_text(
@@ -131,7 +144,7 @@ def test_run_extract_text(
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.outputs is not None
assert result.outputs["text"] == expected_text
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")

View File

@@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node):
},
]
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
for expected_file, result_file in zip(expected_files, result.outputs["result"].value):
assert expected_file["filename"] == result_file.filename
assert expected_file["type"] == result_file.type
assert expected_file["tenant_id"] == result_file.tenant_id

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@@ -63,10 +64,11 @@ def test_overwrite_string_variable():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -77,6 +79,9 @@ def test_overwrite_string_variable():
input_variable,
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
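# The node takes an updater factory rather than an updater instance; mocking the factory lets the test assert on the update()/flush() calls below.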
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -91,11 +96,20 @@ def test_overwrite_string_variable():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -148,9 +162,10 @@ def test_append_variable_to_array():
name="test_string_variable",
value="the second value",
)
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -160,6 +175,9 @@ def test_append_variable_to_array():
input_variable,
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -174,11 +192,22 @@ def test_append_variable_to_array():
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -225,13 +254,17 @@ def test_clear_array():
value=["the first value"],
)
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
@@ -246,11 +279,20 @@ def test_clear_array():
"input_variable_selector": [],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory,
)
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
list(node.run())
mock_run.assert_called_once()
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None

View File

@@ -1,8 +1,12 @@
import pytest
from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
@@ -44,3 +48,38 @@ def test_use_long_selector(pool):
result = pool.get(("node_1", "part_1", "part_2"))
assert result is not None
assert result.value == "test_value"
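# Constructor smoke tests: VariablePool should be constructible with no arguments as well as with fully specified keyword arguments.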
class TestVariablePool:
def test_constructor(self):
pool = VariablePool()
pool = VariablePool(
variable_dictionary={},
user_inputs={},
system_variables={},
environment_variables=[],
conversation_variables=[],
)
pool = VariablePool(
user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
environment_variables=[
segment_to_variable(
segment=build_segment(1),
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"],
name="env_var_1",
)
],
conversation_variables=[
segment_to_variable(
segment=build_segment("1"),
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"],
name="conv_var_1",
)
],
)
def test_constructor_with_invalid_system_variable_key(self):
with pytest.raises(ValidationError):
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore

View File

@@ -1,22 +1,10 @@
from core.variables import SecretVariable
import dataclasses
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
@@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]
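# Malformed references (unbalanced braces, missing node or variable parts) should yield no selectors.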
def test_invalid_references():
@dataclasses.dataclass
class TestCase:
name: str
template: str
cases = [
TestCase(
name="lack of closing brace",
template="Hello, {{#sys.user_id#",
),
TestCase(
name="lack of opening brace",
template="Hello, #sys.user_id#}}",
),
TestCase(
name="lack selector name",
template="Hello, {{#sys#}}",
),
TestCase(
name="empty node name part",
template="Hello, {{#.user_id#}}",
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
selectors = variable_template_parser.extract_selectors_from_template(c.template)
assert selectors == [], fail_msg
parser = variable_template_parser.VariableTemplateParser(c.template)
assert parser.extract_variable_selectors() == [], fail_msg

View File

@@ -0,0 +1,865 @@
import math
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
import pytest
from hypothesis import given
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
SegmentType,
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.variables.types import SegmentType
from factories import variable_factory
from factories.variable_factory import TypeMismatchError, build_segment_with_type
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = variable_factory.build_conversation_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = variable_factory.build_conversation_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = variable_factory.build_segment(
{
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value["key1"] is None
def test_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
def test_array_number_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
def test_array_object_variable():
mapping = {
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
"key1": "text",
"key2": 1,
},
{
"key1": "text",
"key2": 1,
},
],
}
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
variable_factory.build_conversation_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]
def test_build_segment_none_type():
"""Test building NoneSegment from None value."""
segment = variable_factory.build_segment(None)
assert isinstance(segment, NoneSegment)
assert segment.value is None
assert segment.value_type == SegmentType.NONE
def test_build_segment_none_type_properties():
"""Test NoneSegment properties and methods."""
segment = variable_factory.build_segment(None)
assert segment.text == ""
assert segment.log == ""
assert segment.markdown == ""
assert segment.to_object() is None
def test_build_segment_array_file_single_file():
"""Test building ArrayFileSegment from list with single file."""
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
segment = variable_factory.build_segment([file])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 1
assert segment.value[0] == file
assert segment.value_type == SegmentType.ARRAY_FILE
def test_build_segment_array_file_multiple_files():
"""Test building ArrayFileSegment from list with multiple files."""
file1 = File(
id="test_file_id_1",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file1.png",
filename="test-file1",
extension=".png",
mime_type="image/png",
size=1000,
)
file2 = File(
id="test_file_id_2",
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_relation_id",
filename="test-file2",
extension=".txt",
mime_type="text/plain",
size=500,
)
segment = variable_factory.build_segment([file1, file2])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 2
assert segment.value[0] == file1
assert segment.value[1] == file2
assert segment.value_type == SegmentType.ARRAY_FILE
def test_build_segment_array_file_empty_list():
"""Test building ArrayFileSegment from empty list should create ArrayAnySegment."""
segment = variable_factory.build_segment([])
assert isinstance(segment, ArrayAnySegment)
assert segment.value == []
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_mixed_types():
"""Test building ArrayAnySegment from list with mixed types."""
mixed_values = ["string", 42, 3.14, {"key": "value"}, None]
segment = variable_factory.build_segment(mixed_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == mixed_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_with_nested_arrays():
"""Test building ArrayAnySegment from list containing arrays."""
nested_values = [["nested", "array"], [1, 2, 3], "string"]
segment = variable_factory.build_segment(nested_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == nested_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_mixed_with_files():
"""Test building ArrayAnySegment from list with files and other types."""
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
mixed_values = [file, "string", 42]
segment = variable_factory.build_segment(mixed_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == mixed_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_any_all_none_values():
"""Test building ArrayAnySegment from list with all None values."""
none_values = [None, None, None]
segment = variable_factory.build_segment(none_values)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == none_values
assert segment.value_type == SegmentType.ARRAY_ANY
def test_build_segment_array_file_properties():
"""Test ArrayFileSegment properties and methods."""
file1 = File(
id="test_file_id_1",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file1.png",
filename="test-file1",
extension=".png",
mime_type="image/png",
size=1000,
)
file2 = File(
id="test_file_id_2",
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file2.txt",
filename="test-file2",
extension=".txt",
mime_type="text/plain",
size=500,
)
segment = variable_factory.build_segment([file1, file2])
# Test properties
assert segment.text == "" # ArrayFileSegment text property returns empty string
assert segment.log == "" # ArrayFileSegment log property returns empty string
assert segment.markdown == f"{file1.markdown}\n{file2.markdown}"
assert segment.to_object() == [file1, file2]
def test_build_segment_array_any_properties():
"""Test ArrayAnySegment properties and methods."""
mixed_values = ["string", 42, None]
segment = variable_factory.build_segment(mixed_values)
# Test properties
assert segment.text == str(mixed_values)
assert segment.log == str(mixed_values)
assert segment.markdown == "string\n42\nNone"
assert segment.to_object() == mixed_values
def test_build_segment_edge_cases():
"""Test edge cases for build_segment function."""
# Test with complex nested structures
complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"]
segment = variable_factory.build_segment(complex_structure)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == complex_structure
# Test with single None in list
single_none = [None]
segment = variable_factory.build_segment(single_none)
assert isinstance(segment, ArrayAnySegment)
assert segment.value == single_none
def test_build_segment_file_array_with_different_file_types():
"""Test ArrayFileSegment with different file types."""
image_file = File(
id="image_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/image.png",
filename="image",
extension=".png",
mime_type="image/png",
size=1000,
)
video_file = File(
id="video_id",
tenant_id="test_tenant_id",
type=FileType.VIDEO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="video_relation_id",
filename="video",
extension=".mp4",
mime_type="video/mp4",
size=5000,
)
audio_file = File(
id="audio_id",
tenant_id="test_tenant_id",
type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="audio_relation_id",
filename="audio",
extension=".mp3",
mime_type="audio/mpeg",
size=3000,
)
segment = variable_factory.build_segment([image_file, video_file, audio_file])
assert isinstance(segment, ArrayFileSegment)
assert len(segment.value) == 3
assert segment.value[0].type == FileType.IMAGE
assert segment.value[1].type == FileType.VIDEO
assert segment.value[2].type == FileType.AUDIO
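# Hypothesis strategy: draw a File with a consistent (type, mime_type, extension) triple and the fields its transfer method requires.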
@st.composite
def _generate_file(draw) -> File:
file_id = draw(st.text(min_size=1, max_size=10))
tenant_id = draw(st.text(min_size=1, max_size=10))
file_type, mime_type, extension = draw(
st.sampled_from(
[
(FileType.IMAGE, "image/png", ".png"),
(FileType.VIDEO, "video/mp4", ".mp4"),
(FileType.DOCUMENT, "text/plain", ".txt"),
(FileType.AUDIO, "audio/mpeg", ".mp3"),
]
)
)
filename = "test-file"
size = draw(st.integers(min_value=0, max_value=1024 * 1024))
transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
if transfer_method == FileTransferMethod.REMOTE_URL:
url = "https://test.example.com/test-file"
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
remote_url=url,
related_id=None,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
else:
relation_id = draw(st.uuids(version=4))
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=file_type,
transfer_method=transfer_method,
related_id=str(relation_id),
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
return file
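# Any scalar value build_segment accepts: None, int, float, str, or File.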
def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
return st.one_of(
st.none(),
st.integers(),
st.floats(),
st.text(),
_generate_file(),
)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
if isinstance(value, float) and math.isnan(value):
assert math.isnan(seg.value)
else:
assert seg.value == value
@given(st.lists(_scalar_value()))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)
assert seg.value == values
def test_build_segment_type_for_scalar():
@dataclass(frozen=True)
class TestCase:
value: int | float | str | File
expected_type: SegmentType
file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
)
cases = [
TestCase(0, SegmentType.NUMBER),
TestCase(0.0, SegmentType.NUMBER),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
for idx, c in enumerate(cases, 1):
segment = variable_factory.build_segment(c.value)
assert segment.value_type == c.expected_type, f"test case {idx} failed."
class TestBuildSegmentWithType:
"""Test cases for build_segment_with_type function."""
def test_string_type(self):
"""Test building a string segment with correct type."""
result = build_segment_with_type(SegmentType.STRING, "hello")
assert isinstance(result, StringSegment)
assert result.value == "hello"
assert result.value_type == SegmentType.STRING
def test_number_type_integer(self):
"""Test building a number segment with integer value."""
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
assert result.value_type == SegmentType.NUMBER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
assert result.value_type == SegmentType.NUMBER
def test_object_type(self):
"""Test building an object segment with correct type."""
test_obj = {"key": "value", "nested": {"inner": 123}}
result = build_segment_with_type(SegmentType.OBJECT, test_obj)
assert isinstance(result, ObjectSegment)
assert result.value == test_obj
assert result.value_type == SegmentType.OBJECT
def test_file_type(self):
"""Test building a file segment with correct type."""
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
storage_key="test_storage_key",
)
result = build_segment_with_type(SegmentType.FILE, test_file)
assert isinstance(result, FileSegment)
assert result.value == test_file
assert result.value_type == SegmentType.FILE
def test_none_type(self):
"""Test building a none segment with None value."""
result = build_segment_with_type(SegmentType.NONE, None)
assert isinstance(result, NoneSegment)
assert result.value is None
assert result.value_type == SegmentType.NONE
def test_empty_array_string(self):
"""Test building an empty array[string] segment."""
result = build_segment_with_type(SegmentType.ARRAY_STRING, [])
assert isinstance(result, ArrayStringSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_STRING
def test_empty_array_number(self):
"""Test building an empty array[number] segment."""
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [])
assert isinstance(result, ArrayNumberSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_NUMBER
def test_empty_array_object(self):
"""Test building an empty array[object] segment."""
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [])
assert isinstance(result, ArrayObjectSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_OBJECT
def test_empty_array_file(self):
"""Test building an empty array[file] segment."""
result = build_segment_with_type(SegmentType.ARRAY_FILE, [])
assert isinstance(result, ArrayFileSegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_FILE
def test_empty_array_any(self):
"""Test building an empty array[any] segment."""
result = build_segment_with_type(SegmentType.ARRAY_ANY, [])
assert isinstance(result, ArrayAnySegment)
assert result.value == []
assert result.value_type == SegmentType.ARRAY_ANY
def test_array_with_values(self):
"""Test building array segments with actual values."""
# Array of strings
result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"])
assert isinstance(result, ArrayStringSegment)
assert result.value == ["hello", "world"]
assert result.value_type == SegmentType.ARRAY_STRING
# Array of numbers
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14])
assert isinstance(result, ArrayNumberSegment)
assert result.value == [1, 2, 3.14]
assert result.value_type == SegmentType.ARRAY_NUMBER
# Array of objects
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}])
assert isinstance(result, ArrayObjectSegment)
assert result.value == [{"a": 1}, {"b": 2}]
assert result.value_type == SegmentType.ARRAY_OBJECT
def test_type_mismatch_string_to_number(self):
"""Test type mismatch when expecting number but getting string."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.NUMBER, "not_a_number")
assert "Type mismatch" in str(exc_info.value)
assert "expected number" in str(exc_info.value)
assert "str" in str(exc_info.value)
def test_type_mismatch_number_to_string(self):
"""Test type mismatch when expecting string but getting number."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, 123)
assert "Type mismatch" in str(exc_info.value)
assert "expected string" in str(exc_info.value)
assert "int" in str(exc_info.value)
def test_type_mismatch_none_to_string(self):
"""Test type mismatch when expecting string but getting None."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
assert "Expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
assert "Expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"})
assert "Type mismatch" in str(exc_info.value)
assert "expected array[string]" in str(exc_info.value)
def test_compatible_number_types(self):
"""Test that int and float are both compatible with NUMBER type."""
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
assert result_int.value_type == SegmentType.NUMBER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
assert result_float.value_type == SegmentType.NUMBER
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
(SegmentType.NUMBER, 42, IntegerSegment),
(SegmentType.NUMBER, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
(SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment),
(SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment),
(SegmentType.ARRAY_ANY, [], ArrayAnySegment),
],
)
def test_parametrized_valid_types(self, segment_type, value, expected_class):
"""Parametrized test for valid type combinations."""
result = build_segment_with_type(segment_type, value)
assert isinstance(result, expected_class)
assert result.value == value
assert result.value_type == segment_type
@pytest.mark.parametrize(
("segment_type", "value"),
[
(SegmentType.STRING, 123),
(SegmentType.NUMBER, "not_a_number"),
(SegmentType.OBJECT, "not_an_object"),
(SegmentType.ARRAY_STRING, "not_an_array"),
(SegmentType.STRING, None),
(SegmentType.NUMBER, None),
],
)
def test_parametrized_type_mismatches(self, segment_type, value):
"""Parametrized test for type mismatches that should raise TypeMismatchError."""
with pytest.raises(TypeMismatchError):
build_segment_with_type(segment_type, value)
# Test cases for ValueError scenarios in build_segment function
class TestBuildSegmentValueErrors:
"""Test cases for ValueError scenarios in the build_segment function."""
@dataclass(frozen=True)
class ValueErrorTestCase:
"""Test case data for ValueError scenarios."""
name: str
description: str
test_value: Any
def _get_test_cases(self):
"""Get all test cases for ValueError scenarios."""
# Define inline classes for complex test cases
class CustomType:
pass
def unsupported_function():
return "test"
def gen():
yield 1
yield 2
return [
self.ValueErrorTestCase(
name="unsupported_custom_type",
description="custom class that doesn't match any supported type",
test_value=CustomType(),
),
self.ValueErrorTestCase(
name="unsupported_set_type",
description="set (unsupported collection type)",
test_value={1, 2, 3},
),
self.ValueErrorTestCase(
name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3)
),
self.ValueErrorTestCase(
name="unsupported_bytes_type",
description="bytes (unsupported type)",
test_value=b"hello world",
),
self.ValueErrorTestCase(
name="unsupported_function_type",
description="function (unsupported type)",
test_value=unsupported_function,
),
self.ValueErrorTestCase(
name="unsupported_module_type", description="module (unsupported type)", test_value=math
),
self.ValueErrorTestCase(
name="array_with_unsupported_element_types",
description="array with unsupported element types",
test_value=[CustomType()],
),
self.ValueErrorTestCase(
name="mixed_array_with_unsupported_types",
description="array with mix of supported and unsupported types",
test_value=["valid_string", 42, CustomType()],
),
self.ValueErrorTestCase(
name="nested_unsupported_types",
description="nested structures containing unsupported types",
test_value=[{"valid": "data"}, CustomType()],
),
self.ValueErrorTestCase(
name="complex_number_type",
description="complex number (unsupported type)",
test_value=3 + 4j,
),
self.ValueErrorTestCase(
name="range_type", description="range object (unsupported type)", test_value=range(10)
),
self.ValueErrorTestCase(
name="generator_type",
description="generator (unsupported type)",
test_value=gen(),
),
self.ValueErrorTestCase(
name="exception_message_contains_value",
description="set to verify error message contains the actual unsupported value",
test_value={1, 2, 3},
),
self.ValueErrorTestCase(
name="array_with_mixed_unsupported_segment_types",
description="array processing with unsupported segment types in match",
test_value=[CustomType()],
),
self.ValueErrorTestCase(
name="frozenset_type",
description="frozenset (unsupported type)",
test_value=frozenset([1, 2, 3]),
),
self.ValueErrorTestCase(
name="memoryview_type",
description="memoryview (unsupported type)",
test_value=memoryview(b"hello"),
),
self.ValueErrorTestCase(
name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2)
),
self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type),
self.ValueErrorTestCase(
name="generic_object", description="generic object (unsupported type)", test_value=object()
),
]
def test_build_segment_unsupported_types(self):
"""Table-driven test for all ValueError scenarios in build_segment function."""
test_cases = self._get_test_cases()
for index, test_case in enumerate(test_cases, 1):
# Use test value directly
test_value = test_case.test_value
with pytest.raises(ValueError) as exc_info: # noqa: PT012
segment = variable_factory.build_segment(test_value)
pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but not, result={segment}")
error_message = str(exc_info.value)
assert "not supported value" in error_message, (
f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, "
f"but got: {error_message}"
)
def test_build_segment_boolean_type_note(self):
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
# Boolean values in Python are subclasses of int, so they get processed as integers
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
true_segment = variable_factory.build_segment(True)
false_segment = variable_factory.build_segment(False)
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.NUMBER
assert false_segment.value_type == SegmentType.NUMBER

View File

@@ -0,0 +1,20 @@
import datetime
from libs.datetime_utils import naive_utc_now
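# naive_utc_now() should return the current UTC time with tzinfo stripped; patching _now_func pins the clock for a deterministic comparison.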
def test_naive_utc_now(monkeypatch):
tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
return tz_aware_utc_now.astimezone(tz)
monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
naive_datetime = naive_utc_now()
assert naive_datetime.tzinfo is None
assert naive_datetime.date() == tz_aware_utc_now.date()
naive_time = naive_datetime.time()
utc_time = tz_aware_utc_now.time()
assert naive_time == utc_time

View File

@@ -1,10 +1,15 @@
import dataclasses
import json
from unittest import mock
from uuid import uuid4
from constants import HIDDEN_VALUE
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecutionModel
from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
def test_environment_variables():
@@ -163,3 +168,147 @@ class TestWorkflowNodeExecution:
original = {"a": 1, "b": ["2"]}
node_exec.execution_metadata = json.dumps(original)
assert node_exec.execution_metadata_dict == original
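# Only the "query" and "files" system variables are editable; every other system variable, including unknown names, should report as not editable.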
class TestIsSystemVariableEditable:
def test_is_system_variable(self):
cases = [
("query", True),
("files", True),
("dialogue_count", False),
("conversation_id", False),
("user_id", False),
("app_id", False),
("workflow_id", False),
("workflow_run_id", False),
]
for name, editable in cases:
assert editable == is_system_variable_editable(name)
assert not is_system_variable_editable("invalid_or_new_system_variable")
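# Round-trip checks: set_value()/get_value() should preserve every supported segment type, including File fields.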
class TestWorkflowDraftVariableGetValue:
def test_get_value_by_case(self):
@dataclasses.dataclass
class TestCase:
name: str
value: Segment
tenant_id = "test_tenant_id"
test_file = File(
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/example.jpg",
filename="example.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=100,
)
cases: list[TestCase] = [
TestCase(
name="number/int",
value=build_segment(1),
),
TestCase(
name="number/float",
value=build_segment(1.0),
),
TestCase(
name="string",
value=build_segment("a"),
),
TestCase(
name="object",
value=build_segment({}),
),
TestCase(
name="file",
value=build_segment(test_file),
),
TestCase(
name="array[any]",
value=build_segment([1, "a"]),
),
TestCase(
name="array[string]",
value=build_segment(["a", "b"]),
),
TestCase(
name="array[number]/int",
value=build_segment([1, 2]),
),
TestCase(
name="array[number]/float",
value=build_segment([1.0, 2.0]),
),
TestCase(
name="array[number]/mixed",
value=build_segment([1, 2.0]),
),
TestCase(
name="array[object]",
value=build_segment([{}, {"a": 1}]),
),
TestCase(
name="none",
value=build_segment(None),
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"test case {c.name} failed, index={idx}"
draft_var = WorkflowDraftVariable()
draft_var.set_value(c.value)
assert c.value == draft_var.get_value(), fail_msg
def test_file_variable_preserves_all_fields(self):
"""Test that File type variables preserve all fields during encoding/decoding."""
tenant_id = "test_tenant_id"
# Create a File with specific field values
test_file = File(
id="test_file_id",
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345, # Specific size to test preservation
storage_key="test_storage_key",
)
# Create a FileSegment and WorkflowDraftVariable
file_segment = build_segment(test_file)
draft_var = WorkflowDraftVariable()
draft_var.set_value(file_segment)
# Retrieve the value and verify all fields are preserved
retrieved_segment = draft_var.get_value()
retrieved_file = retrieved_segment.value
# Verify all important fields are preserved
assert retrieved_file.id == test_file.id
assert retrieved_file.tenant_id == test_file.tenant_id
assert retrieved_file.type == test_file.type
assert retrieved_file.transfer_method == test_file.transfer_method
assert retrieved_file.remote_url == test_file.remote_url
assert retrieved_file.filename == test_file.filename
assert retrieved_file.extension == test_file.extension
assert retrieved_file.mime_type == test_file.mime_type
assert retrieved_file.size == test_file.size # This was the main issue being fixed
# Note: storage_key is not serialized in model_dump() so it won't be preserved
# Verify the segments have the same type and the important fields match
assert file_segment.value_type == retrieved_segment.value_type
def test_get_and_set_value(self):
draft_var = WorkflowDraftVariable()
int_var = IntegerSegment(value=1)
draft_var.set_value(int_var)
value = draft_var.get_value()
assert value == int_var

View File

@@ -0,0 +1,222 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
WorkflowDraftVariableService,
)
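# DraftVariableSaver tests: visibility rules for generated node IDs and normalization of `sys.`-prefixed names coming from the start node.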
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = mock.MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id="test_node_id",
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
assert not saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output")
assert saver._should_variable_be_visible("123", NodeType.START, "output")
def test__normalize_variable_for_start_node(self):
@dataclasses.dataclass(frozen=True)
class TestCase:
name: str
input_node_id: str
input_name: str
expected_node_id: str
expected_name: str
_NODE_ID = "1747228642872"
cases = [
TestCase(
name="name with `sys.` prefix should return the system node_id",
input_node_id=_NODE_ID,
input_name="sys.workflow_id",
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
expected_name="workflow_id",
),
TestCase(
name="name without `sys.` prefix should return the original input node_id",
input_node_id=_NODE_ID,
input_name="start_input",
expected_node_id=_NODE_ID,
expected_name="start_input",
),
TestCase(
name="dummy_variable should return the original input node_id",
input_node_id=_NODE_ID,
input_name="__dummy__",
expected_node_id=_NODE_ID,
expected_name="__dummy__",
),
]
mock_session = mock.MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id=_NODE_ID,
node_type=NodeType.START,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id="test_execution_id",
)
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
mock_variable.id = "var-id"
mock_variable.name = "test_var"
# Mock the _reset_conv_var method
expected_result = Mock(spec=WorkflowDraftVariable)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
result = service.reset_variable(mock_workflow, mock_variable)
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with no execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = None
mock_variable.id = "var-id"
mock_variable.name = "test_var"
result = service._reset_node_var(mock_workflow, mock_variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var(mock_workflow, mock_variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=mock_variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
# Create mock variable with execution ID
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
mock_variable.node_execution_id = "exec-id"
mock_variable.id = "var-id"
mock_variable.name = "test_var"
mock_variable.node_id = "node-id"
mock_variable.value_type = SegmentType.STRING
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock workflow methods
mock_node_config = {"type": "test_node"}
mock_workflow.get_node_config_by_id.return_value = mock_node_config
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
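# Resetting should restore the variable's value from the node execution record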
result = service._reset_node_var(mock_workflow, mock_variable)
# Verify variable.set_value was called with the correct value
mock_variable.set_value.assert_called_once()
# Verify last_edited_at was reset
assert mock_variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Should return the updated variable
assert result == mock_variable
def test_reset_system_variable_raises_error(self):
"""Test that resetting a system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_workflow = Mock(spec=Workflow)
mock_workflow.app_id = self._get_test_app_id()
mock_variable = Mock(spec=WorkflowDraftVariable)
mock_variable.get_variable_type.return_value = DraftVariableType.SYS  # system variables cannot be reset
mock_variable.id = "var-id"
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(mock_workflow, mock_variable)
assert "cannot reset system variable" in str(exc_info.value)
assert "variable_id=var-id" in str(exc_info.value)