feat: Persist Variables for Enhanced Debugging Workflow (#20699)
This pull request introduces a feature aimed at improving the debugging experience during workflow editing. With the addition of variable persistence, the system will automatically retain the output variables from previously executed nodes. These persisted variables can then be reused when debugging subsequent nodes, eliminating the need for repetitive manual input. By streamlining this aspect of the workflow, the feature minimizes user errors and significantly reduces debugging effort, offering a smoother and more efficient experience. Key highlights of this change: - Automatic persistence of output variables for executed nodes. - Reuse of persisted variables to simplify input steps for nodes requiring them (e.g., `code`, `template`, `variable_assigner`). - Enhanced debugging experience with reduced friction. Closes #19735.
This commit is contained in:
@@ -1,107 +1,217 @@
|
||||
# OpenAI API Key
|
||||
OPENAI_API_KEY=
|
||||
FLASK_APP=app.py
|
||||
FLASK_DEBUG=0
|
||||
SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI'
|
||||
|
||||
# Azure OpenAI API Base Endpoint & API Key
|
||||
AZURE_OPENAI_API_BASE=
|
||||
AZURE_OPENAI_API_KEY=
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Anthropic API Key
|
||||
ANTHROPIC_API_KEY=
|
||||
# Service API base URL
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Replicate API Key
|
||||
REPLICATE_API_KEY=
|
||||
# Web APP base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Hugging Face API Key
|
||||
HUGGINGFACE_API_KEY=
|
||||
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL=
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL=
|
||||
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
|
||||
# Files URL
|
||||
FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# Minimax Credentials
|
||||
MINIMAX_API_KEY=
|
||||
MINIMAX_GROUP_ID=
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Spark Credentials
|
||||
SPARK_APP_ID=
|
||||
SPARK_API_KEY=
|
||||
SPARK_API_SECRET=
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# Tongyi Credentials
|
||||
TONGYI_DASHSCOPE_API_KEY=
|
||||
# Refresh token expiration time in days
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# Wenxin Credentials
|
||||
WENXIN_API_KEY=
|
||||
WENXIN_SECRET_KEY=
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
# ZhipuAI Credentials
|
||||
ZHIPUAI_API_KEY=
|
||||
# redis configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
|
||||
# Baichuan Credentials
|
||||
BAICHUAN_API_KEY=
|
||||
BAICHUAN_SECRET_KEY=
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
# ChatGLM Credentials
|
||||
CHATGLM_API_BASE=
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
|
||||
STORAGE_TYPE=opendal
|
||||
|
||||
# Xinference Credentials
|
||||
XINFERENCE_SERVER_URL=
|
||||
XINFERENCE_GENERATION_MODEL_UID=
|
||||
XINFERENCE_CHAT_MODEL_UID=
|
||||
XINFERENCE_EMBEDDINGS_MODEL_UID=
|
||||
XINFERENCE_RERANK_MODEL_UID=
|
||||
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
|
||||
OPENDAL_SCHEME=fs
|
||||
OPENDAL_FS_ROOT=storage
|
||||
|
||||
# OpenLLM Credentials
|
||||
OPENLLM_SERVER_URL=
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# LocalAI Credentials
|
||||
LOCALAI_SERVER_URL=
|
||||
# Vector database configuration
|
||||
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
|
||||
VECTOR_STORE=weaviate
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# Cohere Credentials
|
||||
COHERE_API_KEY=
|
||||
|
||||
# Jina Credentials
|
||||
JINA_API_KEY=
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Ollama Credentials
|
||||
OLLAMA_BASE_URL=
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=4096
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
|
||||
# Together API Key
|
||||
TOGETHER_API_KEY=
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@example.com>
|
||||
RESEND_API_KEY=
|
||||
RESEND_API_URL=https://api.resend.com
|
||||
# smtp configuration
|
||||
SMTP_SERVER=smtp.example.com
|
||||
SMTP_PORT=465
|
||||
SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=true
|
||||
SMTP_OPPORTUNISTIC_TLS=false
|
||||
|
||||
# Mock Switch
|
||||
MOCK_SWITCH=false
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
# DEBUG
|
||||
DEBUG=false
|
||||
SQLALCHEMY_ECHO=false
|
||||
|
||||
# Notion import configuration, support public and internal
|
||||
NOTION_INTEGRATION_TYPE=public
|
||||
NOTION_CLIENT_SECRET=you-client-secret
|
||||
NOTION_CLIENT_ID=you-client-id
|
||||
NOTION_INTERNAL_SECRET=you-internal-secret
|
||||
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
UNSTRUCTURED_API_KEY=
|
||||
SCARF_NO_ANALYTICS=false
|
||||
|
||||
#ssrf
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
SSRF_PROXY_HTTPS_URL=
|
||||
SSRF_DEFAULT_MAX_RETRIES=3
|
||||
SSRF_DEFAULT_TIME_OUT=5
|
||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
|
||||
# Workflow file upload limit
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=
|
||||
CODE_EXECUTION_API_KEY=
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
|
||||
# Volcengine MaaS Credentials
|
||||
VOLC_API_KEY=
|
||||
VOLC_SECRET_KEY=
|
||||
VOLC_MODEL_ENDPOINT_ID=
|
||||
VOLC_EMBEDDING_ENDPOINT_ID=
|
||||
# API Tool configuration
|
||||
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
|
||||
API_TOOL_DEFAULT_READ_TIMEOUT=60
|
||||
|
||||
# 360 AI Credentials
|
||||
ZHINAO_API_KEY=
|
||||
# HTTP Node configuration
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT=600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
# Log file max size, the unit is MB
|
||||
LOG_FILE_MAX_SIZE=20
|
||||
# Log file max backup count
|
||||
LOG_FILE_BACKUP_COUNT=5
|
||||
# Log dateformat
|
||||
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
# Log Timezone
|
||||
LOG_TZ=UTC
|
||||
# Log format
|
||||
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
|
||||
# Workflow runtime configuration
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
POSITION_TOOL_INCLUDES=
|
||||
POSITION_TOOL_EXCLUDES=
|
||||
|
||||
POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_DAEMON_KEY=
|
||||
PLUGIN_DAEMON_URL=
|
||||
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
|
||||
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
|
||||
PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
MARKETPLACE_API_URL=
|
||||
# VESSL AI Credentials
|
||||
VESSL_AI_MODEL_NAME=
|
||||
VESSL_AI_API_KEY=
|
||||
VESSL_AI_ENDPOINT_URL=
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# GPUStack Credentials
|
||||
GPUSTACK_SERVER_URL=
|
||||
GPUSTACK_API_KEY=
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
# Gitee AI Credentials
|
||||
GITEE_AI_API_KEY=
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
# xAI Credentials
|
||||
XAI_API_KEY=
|
||||
XAI_API_BASE=
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
|
||||
MAX_SUBMIT_COUNT=100
|
||||
# Lockout duration in seconds
|
||||
LOGIN_LOCKOUT_DURATION=86400
|
||||
|
||||
HTTP_PROXY='http://127.0.0.1:1092'
|
||||
HTTPS_PROXY='http://127.0.0.1:1092'
|
||||
NO_PROXY='localhost,127.0.0.1'
|
||||
LOG_LEVEL=INFO
|
||||
|
@@ -1,19 +1,91 @@
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import secrets
|
||||
from collections.abc import Generator
|
||||
|
||||
# Getting the absolute path of the current file's directory
|
||||
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Getting the absolute path of the project's root directory
|
||||
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
|
||||
from app_factory import create_app
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
|
||||
# Loading the .env file if it exists
|
||||
def _load_env() -> None:
|
||||
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
|
||||
if os.path.exists(dotenv_path):
|
||||
current_file_path = pathlib.Path(__file__).absolute()
|
||||
# Items later in the list have higher precedence.
|
||||
files_to_load = [".env", "vdb.env"]
|
||||
|
||||
env_file_paths = [current_file_path.parent / i for i in files_to_load]
|
||||
for path in env_file_paths:
|
||||
if not path.exists():
|
||||
continue
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(dotenv_path)
|
||||
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
|
||||
load_dotenv(str(path), override=True)
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
_CACHED_APP = create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return _CACHED_APP
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def setup_account(request) -> Generator[Account, None, None]:
|
||||
"""`dify_setup` completes the setup process for the Dify application.
|
||||
|
||||
It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.
|
||||
|
||||
Most tests in the `controllers` package may require dify has been successfully setup.
|
||||
"""
|
||||
with _CACHED_APP.test_request_context():
|
||||
rand_suffix = random.randint(int(1e6), int(1e7)) # noqa
|
||||
name = f"test-user-{rand_suffix}"
|
||||
email = f"{name}@example.com"
|
||||
RegisterService.setup(
|
||||
email=email,
|
||||
name=name,
|
||||
password=secrets.token_hex(16),
|
||||
ip_address="localhost",
|
||||
)
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
account = session.query(Account).filter_by(email=email).one()
|
||||
|
||||
yield account
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx():
|
||||
with _CACHED_APP.test_request_context():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(setup_account) -> dict[str, str]:
|
||||
token = AccountService.get_account_jwt_token(setup_account)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client() -> Generator[FlaskClient, None, None]:
|
||||
with _CACHED_APP.test_client() as client:
|
||||
yield client
|
||||
|
@@ -1,25 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from app_factory import create_app
|
||||
from configs import dify_config
|
||||
|
||||
mock_user = type(
|
||||
"MockUser",
|
||||
(object,),
|
||||
{
|
||||
"is_authenticated": True,
|
||||
"id": "123",
|
||||
"is_editor": True,
|
||||
"is_dataset_editor": True,
|
||||
"status": "active",
|
||||
"get_id": "123",
|
||||
"current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = create_app()
|
||||
dify_config.LOGIN_DISABLED = True
|
||||
return app
|
@@ -0,0 +1,47 @@
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
from controllers.console.app import workflow_draft_variable as draft_variable_api
|
||||
from controllers.console.app import wraps
|
||||
from factories.variable_factory import build_segment
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
|
||||
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
|
||||
return mock.create_autospec(WorkflowDraftVariableService)
|
||||
|
||||
|
||||
class TestWorkflowDraftNodeVariableListApi:
|
||||
def test_get(self, test_client, auth_header, monkeypatch):
|
||||
srv_class = _get_mock_srv_class()
|
||||
mock_app_model: App = App()
|
||||
mock_app_model.id = str(uuid.uuid4())
|
||||
test_node_id = "test_node_id"
|
||||
mock_app_model.mode = AppMode.ADVANCED_CHAT
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
|
||||
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id="test_app_1",
|
||||
node_id="test_node_1",
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
|
||||
srv_class.return_value = srv_instance
|
||||
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_dict = response.json
|
||||
assert isinstance(response_dict, dict)
|
||||
assert "items" in response_dict
|
||||
assert len(response_dict["items"]) == 1
|
@@ -1,9 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from app_fixture import mock_user # type: ignore
|
||||
|
||||
|
||||
def test_post_requires_login(app):
|
||||
with app.test_client() as client, patch("flask_login.utils._get_user", mock_user):
|
||||
response = client.get("/console/api/data-source/integrates")
|
||||
assert response.status_code == 200
|
0
api/tests/integration_tests/factories/__init__.py
Normal file
0
api/tests/integration_tests/factories/__init__.py
Normal file
371
api/tests/integration_tests/factories/test_storage_key_loader.py
Normal file
371
api/tests/integration_tests/factories/test_storage_key_loader.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestStorageKeyLoader(unittest.TestCase):
|
||||
"""
|
||||
Integration tests for StorageKeyLoader class.
|
||||
|
||||
Tests the batched loading of storage keys from the database for files
|
||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data before each test method."""
|
||||
self.session = db.session()
|
||||
self.tenant_id = str(uuid4())
|
||||
self.user_id = str(uuid4())
|
||||
self.conversation_id = str(uuid4())
|
||||
|
||||
# Create test data that will be cleaned up after each test
|
||||
self.test_upload_files = []
|
||||
self.test_tool_files = []
|
||||
|
||||
# Create StorageKeyLoader instance
|
||||
self.loader = StorageKeyLoader(self.session, self.tenant_id)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test data after each test method."""
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if storage_key is None:
|
||||
storage_key = f"test_storage_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
self.session.add(upload_file)
|
||||
self.session.flush()
|
||||
self.test_upload_files.append(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if file_key is None:
|
||||
file_key = f"test_file_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
tool_file = ToolFile()
|
||||
tool_file.id = file_id
|
||||
tool_file.user_id = self.user_id
|
||||
tool_file.tenant_id = tenant_id
|
||||
tool_file.conversation_id = self.conversation_id
|
||||
tool_file.file_key = file_key
|
||||
tool_file.mimetype = "text/plain"
|
||||
tool_file.original_url = "http://example.com/file.txt"
|
||||
tool_file.name = "test_tool_file.txt"
|
||||
tool_file.size = 2048
|
||||
|
||||
self.session.add(tool_file)
|
||||
self.session.flush()
|
||||
self.test_tool_files.append(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(
|
||||
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
|
||||
) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
||||
file_related_id = None
|
||||
remote_url = None
|
||||
|
||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
||||
file_related_id = related_id
|
||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
remote_url = "https://example.com/test_file.txt"
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
def test_load_storage_keys_local_file(self):
|
||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_remote_url(self):
|
||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_tool_file(self):
|
||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||
# Create test data
|
||||
tool_file = self._create_tool_file()
|
||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_mixed_methods(self):
|
||||
"""Test batch loading with mixed transfer methods."""
|
||||
# Create test data for different transfer methods
|
||||
upload_file1 = self._create_upload_file()
|
||||
upload_file2 = self._create_upload_file()
|
||||
tool_file = self._create_tool_file()
|
||||
|
||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
files = [file1, file2, file3]
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
assert file1._storage_key == upload_file1.key
|
||||
assert file2._storage_key == upload_file2.key
|
||||
assert file3._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_empty_list(self):
|
||||
"""Test with empty file list."""
|
||||
# Should not raise any exceptions
|
||||
self.loader.load_storage_keys([])
|
||||
|
||||
def test_load_storage_keys_tenant_mismatch(self):
|
||||
"""Test tenant_id validation."""
|
||||
# Create file with different tenant_id
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(
|
||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
||||
)
|
||||
|
||||
# Should raise ValueError for tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_missing_file_id(self):
|
||||
"""Test with None file.related_id."""
|
||||
# Create a file with valid parameters first, then manually set related_id to None
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = None
|
||||
|
||||
# Should raise ValueError for None file related_id
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert str(context.value) == "file id should not be None."
|
||||
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
||||
"""Test with missing UploadFile database records."""
|
||||
# Create file with non-existent upload file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
||||
"""Test with missing ToolFile database records."""
|
||||
# Create file with non-existent tool file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_invalid_uuid(self):
|
||||
"""Test with invalid UUID format."""
|
||||
# Create a file with valid parameters first, then manually set invalid related_id
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = "invalid-uuid-format"
|
||||
|
||||
# Should raise ValueError for invalid UUID
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_batch_efficiency(self):
|
||||
"""Test batched operations use efficient queries."""
|
||||
# Create multiple files of different types
|
||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
||||
|
||||
files = []
|
||||
files.extend(
|
||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
||||
)
|
||||
files.extend(
|
||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
||||
)
|
||||
|
||||
# Mock the session to count queries
|
||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
||||
assert mock_scalars.call_count == 2
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
for i, file in enumerate(files[:3]):
|
||||
assert file._storage_key == upload_files[i].key
|
||||
for i, file in enumerate(files[3:]):
|
||||
assert file._storage_key == tool_files[i].file_key
|
||||
|
||||
def test_load_storage_keys_tenant_isolation(self):
|
||||
"""Test that tenant isolation works correctly."""
|
||||
# Create files for different tenants
|
||||
other_tenant_id = str(uuid4())
|
||||
|
||||
# Create upload file for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file_other.id = str(uuid4())
|
||||
self.session.add(upload_file_other)
|
||||
self.session.flush()
|
||||
|
||||
# Create file for other tenant but try to load with current tenant's loader
|
||||
file_other = self._create_file(
|
||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError due to tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
# Current tenant's file should still work
|
||||
self.loader.load_storage_keys([file_current])
|
||||
assert file_current._storage_key == upload_file_current.key
|
||||
|
||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
||||
# Create files for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create file for different tenant
|
||||
other_tenant_id = str(uuid4())
|
||||
file_other = self._create_file(
|
||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError on tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_current, file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_duplicate_file_ids(self):
|
||||
"""Test handling of duplicate file IDs in the batch."""
|
||||
# Create upload file
|
||||
upload_file = self._create_upload_file()
|
||||
|
||||
# Create two File objects with same related_id
|
||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should handle duplicates gracefully
|
||||
self.loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Both files should have the same storage key
|
||||
assert file1._storage_key == upload_file.key
|
||||
assert file2._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_session_isolation(self):
|
||||
"""Test that the loader uses the provided session correctly."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Create loader with different session (same underlying connection)
|
||||
|
||||
with Session(bind=db.engine) as other_session:
|
||||
other_loader = StorageKeyLoader(other_session, self.tenant_id)
|
||||
with pytest.raises(ValueError):
|
||||
other_loader.load_storage_keys([file])
|
0
api/tests/integration_tests/services/__init__.py
Normal file
0
api/tests/integration_tests/services/__init__.py
Normal file
@@ -0,0 +1,501 @@
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models import db
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_session: Session
|
||||
_node1_id = "test_node_1"
|
||||
_node2_id = "test_node_2"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
visible=False,
|
||||
node_execution_id=self._node_exec_id,
|
||||
),
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
),
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
_variables = list(node2_vars)
|
||||
_variables.extend(
|
||||
[
|
||||
node1_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
)
|
||||
|
||||
db.session.add_all(_variables)
|
||||
db.session.flush()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node1_str_var_id = node1_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
self._node2_var_ids = [v.id for v in node2_vars]
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
|
||||
def test_list_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
|
||||
assert var_list.total == 5
|
||||
assert len(var_list.variables) == 2
|
||||
page1_var_ids = {v.id for v in var_list.variables}
|
||||
assert page1_var_ids.issubset(self._variable_ids)
|
||||
|
||||
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
|
||||
assert var_list_2.total is None
|
||||
assert len(var_list_2.variables) == 2
|
||||
page2_var_ids = {v.id for v in var_list_2.variables}
|
||||
assert page2_var_ids.isdisjoint(page1_var_ids)
|
||||
assert page2_var_ids.issubset(self._variable_ids)
|
||||
|
||||
def test_get_node_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||
assert node_var is not None
|
||||
assert node_var.id == self._node1_str_var_id
|
||||
assert node_var.name == "str_var"
|
||||
assert node_var.get_value() == build_segment("str_value")
|
||||
|
||||
def test_get_system_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
|
||||
assert sys_var is not None
|
||||
assert sys_var.id == self._sys_var_id
|
||||
assert sys_var.name == "sys_var"
|
||||
assert sys_var.get_value() == build_segment("sys_value")
|
||||
|
||||
def test_get_conversation_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
|
||||
assert conv_var is not None
|
||||
assert conv_var.id == self._conv_var_id
|
||||
assert conv_var.name == "conv_var"
|
||||
assert conv_var.get_value() == build_segment("conv_value")
|
||||
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||
WorkflowDraftVariable.node_id == self._node2_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
assert node2_var_count == 0
|
||||
|
||||
def test_delete_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_1_var = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||
)
|
||||
srv.delete_variable(node_1_var)
|
||||
exists = bool(
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
def test__list_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||
assert len(node_vars.variables) == 2
|
||||
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||
|
||||
def test_get_draft_variables_by_selectors(self):
|
||||
srv = self._get_test_srv()
|
||||
selectors = [
|
||||
[self._node1_id, "str_var"],
|
||||
[self._node2_id, "str_var"],
|
||||
[self._node2_id, "int_var"],
|
||||
]
|
||||
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||
assert len(variables) == 3
|
||||
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestDraftVariableLoader(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_test_tenant_id: str
|
||||
|
||||
_node1_id = "test_loader_node_1"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
_variables = [
|
||||
node_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.add_all(_variables)
|
||||
session.flush()
|
||||
session.commit()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node_var_id = node_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
variables = var_loader.load_variables([])
|
||||
assert len(variables) == 0
|
||||
|
||||
def test_variable_loader_with_non_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||
[self._node1_id, "str_var"],
|
||||
]
|
||||
)
|
||||
assert len(variables) == 3
|
||||
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
|
||||
assert conv_var.id == self._conv_var_id
|
||||
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||
assert sys_var.id == self._sys_var_id
|
||||
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||
assert node1_var.id == self._node_var_id
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
"""Integration tests for reset_variable functionality using real database"""
|
||||
|
||||
_test_app_id: str
|
||||
_test_tenant_id: str
|
||||
_test_workflow_id: str
|
||||
_session: Session
|
||||
_node_id = "test_reset_node"
|
||||
_node_exec_id: str
|
||||
_workflow_node_exec_id: str
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
self._test_workflow_id = str(uuid.uuid4())
|
||||
self._node_exec_id = str(uuid.uuid4())
|
||||
self._workflow_node_exec_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
|
||||
# Create a workflow node execution record with outputs
|
||||
# Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable
|
||||
self._workflow_node_execution = WorkflowNodeExecutionModel(
|
||||
id=self._node_exec_id, # This should match the node_execution_id in the variable
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
workflow_id=self._test_workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
index=1,
|
||||
node_execution_id=self._node_exec_id,
|
||||
node_id=self._node_id,
|
||||
node_type=NodeType.LLM.value,
|
||||
title="Test Node",
|
||||
inputs='{"input": "test input"}',
|
||||
process_data='{"test_var": "process_value", "other_var": "other_process"}',
|
||||
outputs='{"test_var": "output_value", "other_var": "other_output"}',
|
||||
status="succeeded",
|
||||
elapsed_time=1.5,
|
||||
created_by_role="account",
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Create conversation variables for the workflow
|
||||
self._conv_variables = [
|
||||
StringVariable(
|
||||
id=str(uuid.uuid4()),
|
||||
name="conv_var_1",
|
||||
description="Test conversation variable 1",
|
||||
value="default_value_1",
|
||||
),
|
||||
StringVariable(
|
||||
id=str(uuid.uuid4()),
|
||||
name="conv_var_2",
|
||||
description="Test conversation variable 2",
|
||||
value="default_value_2",
|
||||
),
|
||||
]
|
||||
|
||||
# Create test variables
|
||||
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="test_var",
|
||||
value=build_segment("old_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now()
|
||||
|
||||
self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="no_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
node_execution_id="temp_exec_id",
|
||||
)
|
||||
# Manually set node_execution_id to None after creation
|
||||
self._node_var_without_exec.node_execution_id = None
|
||||
|
||||
self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="missing_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database
|
||||
)
|
||||
|
||||
self._conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var_1",
|
||||
value=build_segment("old_conv_value"),
|
||||
)
|
||||
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
|
||||
|
||||
# Add all to database
|
||||
db.session.add_all(
|
||||
[
|
||||
self._workflow_node_execution,
|
||||
self._node_var_with_exec,
|
||||
self._node_var_without_exec,
|
||||
self._node_var_missing_exec,
|
||||
self._conv_var,
|
||||
]
|
||||
)
|
||||
db.session.flush()
|
||||
|
||||
# Store IDs for assertions
|
||||
self._node_var_with_exec_id = self._node_var_with_exec.id
|
||||
self._node_var_without_exec_id = self._node_var_without_exec.id
|
||||
self._node_var_missing_exec_id = self._node_var_missing_exec.id
|
||||
self._conv_var_id = self._conv_var.id
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
def _create_mock_workflow(self) -> Workflow:
|
||||
"""Create a real workflow with conversation variables and graph"""
|
||||
conversation_vars = self._conv_variables
|
||||
|
||||
# Create a simple graph with the test node
|
||||
graph = {
|
||||
"nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
app_id=self._test_app_id,
|
||||
type="workflow",
|
||||
version="1.0",
|
||||
graph=json.dumps(graph),
|
||||
features="{}",
|
||||
created_by=str(uuid.uuid4()),
|
||||
environment_variables=[],
|
||||
conversation_variables=conversation_vars,
|
||||
)
|
||||
return workflow
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(self):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_with_exec_id)
|
||||
assert variable is not None
|
||||
assert variable.get_value().value == "old_value"
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return the updated variable
|
||||
assert result is not None
|
||||
assert result.id == self._node_var_with_exec_id
|
||||
assert result.node_execution_id == self._workflow_node_execution.id
|
||||
assert result.last_edited_at is None # Should be reset to None
|
||||
|
||||
# The returned variable should have the updated value from execution record
|
||||
assert result.get_value().value == "output_value"
|
||||
|
||||
# Verify the variable was updated in database
|
||||
updated_variable = srv.get_variable(self._node_var_with_exec_id)
|
||||
assert updated_variable is not None
|
||||
# The value should be updated from the execution record's outputs
|
||||
assert updated_variable.get_value().value == "output_value"
|
||||
assert updated_variable.last_edited_at is None
|
||||
assert updated_variable.node_execution_id == self._workflow_node_execution.id
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_without_exec_id)
|
||||
assert variable is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return None (variable deleted)
|
||||
assert result is None
|
||||
|
||||
# Verify the variable was deleted
|
||||
deleted_variable = srv.get_variable(self._node_var_without_exec_id)
|
||||
assert deleted_variable is None
|
||||
|
||||
def test_reset_node_variable_with_missing_execution_record(self):
|
||||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_missing_exec_id)
|
||||
assert variable is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return None (variable deleted)
|
||||
assert result is None
|
||||
|
||||
# Verify the variable was deleted
|
||||
deleted_variable = srv.get_variable(self._node_var_missing_exec_id)
|
||||
assert deleted_variable is None
|
||||
|
||||
def test_reset_conversation_variable(self):
|
||||
"""Test resetting a conversation variable"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._conv_var_id)
|
||||
assert variable is not None
|
||||
assert variable.get_value().value == "old_conv_value"
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return the updated variable
|
||||
assert result is not None
|
||||
assert result.id == self._conv_var_id
|
||||
assert result.last_edited_at is None # Should be reset to None
|
||||
|
||||
# Verify the variable was updated with default value from workflow
|
||||
updated_variable = srv.get_variable(self._conv_var_id)
|
||||
assert updated_variable is not None
|
||||
# The value should be updated from the workflow's conversation variable default
|
||||
assert updated_variable.get_value().value == "default_value_1"
|
||||
assert updated_variable.last_edited_at is None
|
||||
|
||||
def test_reset_system_variable_raises_error(self):
|
||||
"""Test that resetting a system variable raises an error"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Create a system variable
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
db.session.add(sys_var)
|
||||
db.session.flush()
|
||||
|
||||
# Attempt to reset the system variable
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
srv.reset_variable(mock_workflow, sys_var)
|
||||
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert sys_var.id in str(exc_info.value)
|
@@ -8,8 +8,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory import create_app
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
@@ -30,21 +28,6 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def app():
|
||||
# Set up storage configuration
|
||||
os.environ["STORAGE_TYPE"] = "opendal"
|
||||
os.environ["OPENDAL_SCHEME"] = "fs"
|
||||
os.environ["OPENDAL_FS_ROOT"] = "storage"
|
||||
|
||||
# Ensure storage directory exists
|
||||
os.makedirs("storage", exist_ok=True)
|
||||
|
||||
app = create_app()
|
||||
dify_config.LOGIN_DISABLED = True
|
||||
return app
|
||||
|
||||
|
||||
def init_llm_node(config: dict) -> LLMNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
@@ -102,197 +85,195 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
return node
|
||||
|
||||
|
||||
def test_execute_llm(app):
|
||||
with app.app_context():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {
|
||||
"provider": "langgenius/openai/openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
|
||||
},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
def test_execute_llm(flask_req_ctx):
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {
|
||||
"provider": "langgenius/openai/openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
|
||||
},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
|
||||
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
|
||||
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
|
||||
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
|
||||
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "langgenius/openai/openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "langgenius/openai/openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
|
||||
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
|
||||
"""
|
||||
Test execute LLM node with jinja2
|
||||
"""
|
||||
with app.app_context():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
||||
{"variable": "output", "value_selector": ["abc", "output"]},
|
||||
]
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
||||
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
||||
"edition_type": "jinja2",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"text": "{{#sys.query#}}",
|
||||
"jinja2_text": "{{sys_query}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
||||
{"variable": "output", "value_selector": ["abc", "output"]},
|
||||
]
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
||||
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
||||
"edition_type": "jinja2",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"text": "{{#sys.query#}}",
|
||||
"jinja2_text": "{{sys_query}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
|
||||
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
|
||||
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
|
||||
|
||||
|
||||
def test_extract_json():
|
||||
|
@@ -0,0 +1,302 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from flask_restful import marshal
|
||||
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
_TEST_APP_ID = "test_app_id"
|
||||
_TEST_NODE_EXEC_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableFields:
|
||||
def test_conversation_variable(self):
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||
)
|
||||
|
||||
conv_var.id = str(uuid.uuid4())
|
||||
conv_var.visible = True
|
||||
|
||||
expected_without_value: OrderedDict[str, Any] = OrderedDict(
|
||||
{
|
||||
"id": str(conv_var.id),
|
||||
"type": conv_var.get_variable_type().value,
|
||||
"name": "conv_var",
|
||||
"description": "",
|
||||
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||
"value_type": "number",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = 1
|
||||
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_create_sys_variable(self):
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
name="sys_var",
|
||||
value=build_segment("a"),
|
||||
editable=True,
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
|
||||
sys_var.id = str(uuid.uuid4())
|
||||
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
sys_var.visible = True
|
||||
|
||||
expected_without_value = OrderedDict(
|
||||
{
|
||||
"id": str(sys_var.id),
|
||||
"type": sys_var.get_variable_type().value,
|
||||
"name": "sys_var",
|
||||
"description": "",
|
||||
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
"value_type": "string",
|
||||
"edited": True,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = "a"
|
||||
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
def test_node_variable(self):
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="node_var",
|
||||
value=build_segment([1, "a"]),
|
||||
visible=False,
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
expected_without_value: OrderedDict[str, Any] = OrderedDict(
|
||||
{
|
||||
"id": str(node_var.id),
|
||||
"type": node_var.get_variable_type().value,
|
||||
"name": "node_var",
|
||||
"description": "",
|
||||
"selector": ["test_node", "node_var"],
|
||||
"value_type": "array[any]",
|
||||
"edited": True,
|
||||
"visible": False,
|
||||
}
|
||||
)
|
||||
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
expected_with_value["value"] = [1, "a"]
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableList:
|
||||
def test_workflow_draft_variable_list(self):
|
||||
class TestCase(NamedTuple):
|
||||
name: str
|
||||
var_list: WorkflowDraftVariableList
|
||||
expected: dict
|
||||
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="test_var",
|
||||
value=build_segment("a"),
|
||||
visible=True,
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
node_var.id = str(uuid.uuid4())
|
||||
node_var_dict = OrderedDict(
|
||||
{
|
||||
"id": str(node_var.id),
|
||||
"type": node_var.get_variable_type().value,
|
||||
"name": "test_var",
|
||||
"description": "",
|
||||
"selector": ["test_node", "test_var"],
|
||||
"value_type": "string",
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
}
|
||||
)
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
name="empty variable list",
|
||||
var_list=WorkflowDraftVariableList(variables=[]),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [],
|
||||
"total": None,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="empty variable list with total",
|
||||
var_list=WorkflowDraftVariableList(variables=[], total=10),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [],
|
||||
"total": 10,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="non-empty variable list",
|
||||
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [node_var_dict],
|
||||
"total": None,
|
||||
}
|
||||
),
|
||||
),
|
||||
TestCase(
|
||||
name="non-empty variable list with total",
|
||||
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
|
||||
expected=OrderedDict(
|
||||
{
|
||||
"items": [node_var_dict],
|
||||
"total": 10,
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for idx, case in enumerate(cases, 1):
|
||||
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
|
||||
f"Test case {idx} failed, {case.name=}"
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_node_variables_fields():
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||
)
|
||||
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
assert isinstance(resp, dict)
|
||||
assert len(resp["items"]) == 1
|
||||
item_dict = resp["items"][0]
|
||||
assert item_dict["name"] == "conv_var"
|
||||
assert item_dict["value"] == 1
|
||||
|
||||
|
||||
def test_workflow_file_variable_with_signed_url():
|
||||
"""Test that File type variables include signed URLs in API responses."""
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
|
||||
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_upload_file_id",
|
||||
filename="test.jpg",
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
size=12345,
|
||||
)
|
||||
|
||||
# Create a WorkflowDraftVariable with the File
|
||||
file_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="file_var",
|
||||
value=build_segment(test_file),
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
|
||||
# Marshal the variable using the API fields
|
||||
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
|
||||
# Verify the response structure
|
||||
assert isinstance(resp, dict)
|
||||
assert len(resp["items"]) == 1
|
||||
item_dict = resp["items"][0]
|
||||
assert item_dict["name"] == "file_var"
|
||||
|
||||
# Verify the value is a dict (File.to_dict() result) and contains expected fields
|
||||
value = item_dict["value"]
|
||||
assert isinstance(value, dict)
|
||||
|
||||
# Verify the File fields are preserved
|
||||
assert value["id"] == test_file.id
|
||||
assert value["filename"] == test_file.filename
|
||||
assert value["type"] == test_file.type.value
|
||||
assert value["transfer_method"] == test_file.transfer_method.value
|
||||
assert value["size"] == test_file.size
|
||||
|
||||
# Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method)
|
||||
remote_url = value["remote_url"]
|
||||
assert remote_url is not None
|
||||
|
||||
assert isinstance(remote_url, str)
|
||||
# For LOCAL_FILE, the URL should contain signature parameters
|
||||
assert "timestamp=" in remote_url
|
||||
assert "nonce=" in remote_url
|
||||
assert "sign=" in remote_url
|
||||
|
||||
|
||||
def test_workflow_file_variable_remote_url():
|
||||
"""Test that File type variables with REMOTE_URL transfer method return the remote URL."""
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
|
||||
# Create a File object with REMOTE_URL transfer method
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/test.jpg",
|
||||
filename="test.jpg",
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
size=12345,
|
||||
)
|
||||
|
||||
# Create a WorkflowDraftVariable with the File
|
||||
file_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=_TEST_APP_ID,
|
||||
node_id="test_node",
|
||||
name="file_var",
|
||||
value=build_segment(test_file),
|
||||
node_execution_id=_TEST_NODE_EXEC_ID,
|
||||
)
|
||||
|
||||
# Marshal the variable using the API fields
|
||||
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
|
||||
# Verify the response structure
|
||||
assert isinstance(resp, dict)
|
||||
assert len(resp["items"]) == 1
|
||||
item_dict = resp["items"][0]
|
||||
assert item_dict["name"] == "file_var"
|
||||
|
||||
# Verify the value is a dict (File.to_dict() result) and contains expected fields
|
||||
value = item_dict["value"]
|
||||
assert isinstance(value, dict)
|
||||
remote_url = value["remote_url"]
|
||||
|
||||
# For REMOTE_URL, the URL should be the original remote URL
|
||||
assert remote_url == test_file.remote_url
|
@@ -1,165 +0,0 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, StringVariable)
|
||||
|
||||
|
||||
def test_integer_variable():
|
||||
test_data = {"value_type": "number", "name": "test_int", "value": 42}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, IntegerVariable)
|
||||
|
||||
|
||||
def test_float_variable():
|
||||
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, FloatVariable)
|
||||
|
||||
|
||||
def test_secret_variable():
|
||||
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, SecretVariable)
|
||||
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
|
||||
with pytest.raises(VariableError):
|
||||
variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
def test_build_a_blank_string():
|
||||
result = variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"value_type": "string",
|
||||
"name": "blank",
|
||||
"value": "",
|
||||
}
|
||||
)
|
||||
assert isinstance(result, StringVariable)
|
||||
assert result.value == ""
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = variable_factory.build_segment(
|
||||
{
|
||||
"key1": None,
|
||||
}
|
||||
)
|
||||
assert isinstance(var, ObjectSegment)
|
||||
assert var.value["key1"] is None
|
||||
|
||||
|
||||
def test_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "object",
|
||||
"name": "test_object",
|
||||
"description": "Description of the variable.",
|
||||
"value": {
|
||||
"key1": "text",
|
||||
"key2": 2,
|
||||
},
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value["key1"], str)
|
||||
assert isinstance(variable.value["key2"], int)
|
||||
|
||||
|
||||
def test_array_string_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[string]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
"text",
|
||||
"text",
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
|
||||
|
||||
def test_array_number_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[number]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
1,
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[object]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
{
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
{
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
assert isinstance(variable.value[0]["key1"], str)
|
||||
assert isinstance(variable.value[0]["key2"], int)
|
||||
assert isinstance(variable.value[1]["key1"], str)
|
||||
assert isinstance(variable.value[1]["key2"], int)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_200_kb():
|
||||
with pytest.raises(VariableError):
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"value_type": "string",
|
||||
"name": "test_text",
|
||||
"value": "a" * 1024 * 201,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_array_none_variable():
|
||||
var = variable_factory.build_segment([None, None, None, None])
|
||||
assert isinstance(var, ArrayAnySegment)
|
||||
assert var.value == [None, None, None, None]
|
25
api/tests/unit_tests/core/file/test_models.py
Normal file
25
api/tests/unit_tests/core/file/test_models.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def test_file():
|
||||
file = File(
|
||||
id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="test-related-id",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=67,
|
||||
storage_key="test-storage-key",
|
||||
url="https://example.com/image.png",
|
||||
)
|
||||
assert file.tenant_id == "test-tenant-id"
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.related_id == "test-related-id"
|
||||
assert file.filename == "image.png"
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == "image/png"
|
||||
assert file.size == 67
|
@@ -3,6 +3,7 @@ import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@@ -197,7 +198,7 @@ def test_run():
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 20
|
||||
|
||||
@@ -413,7 +414,7 @@ def test_run_parallel():
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 32
|
||||
|
||||
@@ -654,7 +655,7 @@ def test_iteration_run_in_parallel_mode():
|
||||
parallel_arr.append(item)
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 32
|
||||
|
||||
for item in sequential_result:
|
||||
@@ -662,7 +663,7 @@ def test_iteration_run_in_parallel_mode():
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert item.run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 64
|
||||
|
||||
|
||||
@@ -846,7 +847,7 @@ def test_iteration_run_error_handle():
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": [None, None]}
|
||||
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}
|
||||
|
||||
assert count == 14
|
||||
# execute remove abnormal output
|
||||
@@ -857,5 +858,5 @@ def test_iteration_run_error_handle():
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": []}
|
||||
assert item.run_result.outputs == {"output": ArrayAnySegment(value=[])}
|
||||
assert count == 14
|
||||
|
@@ -7,6 +7,7 @@ from docx.oxml.text.paragraph import CT_P
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
@@ -69,7 +70,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
|
||||
[
|
||||
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
|
||||
(
|
||||
"text/plain",
|
||||
b"Hello, world!",
|
||||
["Hello, world!"],
|
||||
FileTransferMethod.LOCAL_FILE,
|
||||
".txt",
|
||||
),
|
||||
(
|
||||
"application/pdf",
|
||||
b"%PDF-1.5\n%Test PDF content",
|
||||
@@ -84,7 +91,13 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
"",
|
||||
),
|
||||
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
|
||||
(
|
||||
"text/plain",
|
||||
b"Remote content",
|
||||
["Remote content"],
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_run_extract_text(
|
||||
@@ -131,7 +144,7 @@ def test_run_extract_text(
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == expected_text
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
|
||||
|
@@ -115,7 +115,7 @@ def test_filter_files_by_type(list_operator_node):
|
||||
},
|
||||
]
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
for expected_file, result_file in zip(expected_files, result.outputs["result"]):
|
||||
for expected_file, result_file in zip(expected_files, result.outputs["result"].value):
|
||||
assert expected_file["filename"] == result_file.filename
|
||||
assert expected_file["type"] == result_file.type
|
||||
assert expected_file["tenant_id"] == result_file.tenant_id
|
||||
|
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@@ -63,10 +64,11 @@ def test_overwrite_string_variable():
|
||||
name="test_string_variable",
|
||||
value="the second value",
|
||||
)
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@@ -77,6 +79,9 @@ def test_overwrite_string_variable():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
@@ -91,11 +96,20 @@ def test_overwrite_string_variable():
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
list(node.run())
|
||||
expected_var = StringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=input_variable.value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@@ -148,9 +162,10 @@ def test_append_variable_to_array():
|
||||
name="test_string_variable",
|
||||
value="the second value",
|
||||
)
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@@ -160,6 +175,9 @@ def test_append_variable_to_array():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
@@ -174,11 +192,22 @@ def test_append_variable_to_array():
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
list(node.run())
|
||||
expected_value = list(conversation_variable.value)
|
||||
expected_value.append(input_variable.value)
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=expected_value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@@ -225,13 +254,17 @@ def test_clear_array():
|
||||
value=["the first value"],
|
||||
)
|
||||
|
||||
conversation_id = str(uuid.uuid4())
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
@@ -246,11 +279,20 @@ def test_clear_array():
|
||||
"input_variable_selector": [],
|
||||
},
|
||||
},
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
list(node.run())
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=[],
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
|
@@ -1,8 +1,12 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import FileSegment, StringSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from factories.variable_factory import build_segment, segment_to_variable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -44,3 +48,38 @@ def test_use_long_selector(pool):
|
||||
result = pool.get(("node_1", "part_1", "part_2"))
|
||||
assert result is not None
|
||||
assert result.value == "test_value"
|
||||
|
||||
|
||||
class TestVariablePool:
|
||||
def test_constructor(self):
|
||||
pool = VariablePool()
|
||||
pool = VariablePool(
|
||||
variable_dictionary={},
|
||||
user_inputs={},
|
||||
system_variables={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
pool = VariablePool(
|
||||
user_inputs={"key": "value"},
|
||||
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
|
||||
environment_variables=[
|
||||
segment_to_variable(
|
||||
segment=build_segment(1),
|
||||
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"],
|
||||
name="env_var_1",
|
||||
)
|
||||
],
|
||||
conversation_variables=[
|
||||
segment_to_variable(
|
||||
segment=build_segment("1"),
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"],
|
||||
name="conv_var_1",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def test_constructor_with_invalid_system_variable_key(self):
|
||||
with pytest.raises(ValidationError):
|
||||
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
|
||||
|
@@ -1,22 +1,10 @@
|
||||
from core.variables import SecretVariable
|
||||
import dataclasses
|
||||
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.utils import variable_template_parser
|
||||
|
||||
|
||||
def test_extract_selectors_from_template():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
@@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
|
||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||
]
|
||||
|
||||
|
||||
def test_invalid_references():
|
||||
@dataclasses.dataclass
|
||||
class TestCase:
|
||||
name: str
|
||||
template: str
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
name="lack of closing brace",
|
||||
template="Hello, {{#sys.user_id#",
|
||||
),
|
||||
TestCase(
|
||||
name="lack of opening brace",
|
||||
template="Hello, #sys.user_id#}}",
|
||||
),
|
||||
TestCase(
|
||||
name="lack selector name",
|
||||
template="Hello, {{#sys#}}",
|
||||
),
|
||||
TestCase(
|
||||
name="empty node name part",
|
||||
template="Hello, {{#.user_id#}}",
|
||||
),
|
||||
]
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
selectors = variable_template_parser.extract_selectors_from_template(c.template)
|
||||
assert selectors == [], fail_msg
|
||||
parser = variable_template_parser.VariableTemplateParser(c.template)
|
||||
assert parser.extract_variable_selectors() == [], fail_msg
|
||||
|
865
api/tests/unit_tests/factories/test_variable_factory.py
Normal file
865
api/tests/unit_tests/factories/test_variable_factory.py
Normal file
@@ -0,0 +1,865 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from factories import variable_factory
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, StringVariable)
|
||||
|
||||
|
||||
def test_integer_variable():
|
||||
test_data = {"value_type": "number", "name": "test_int", "value": 42}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, IntegerVariable)
|
||||
|
||||
|
||||
def test_float_variable():
|
||||
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, FloatVariable)
|
||||
|
||||
|
||||
def test_secret_variable():
|
||||
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
|
||||
result = variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
assert isinstance(result, SecretVariable)
|
||||
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
|
||||
with pytest.raises(VariableError):
|
||||
variable_factory.build_conversation_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
def test_build_a_blank_string():
|
||||
result = variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"value_type": "string",
|
||||
"name": "blank",
|
||||
"value": "",
|
||||
}
|
||||
)
|
||||
assert isinstance(result, StringVariable)
|
||||
assert result.value == ""
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = variable_factory.build_segment(
|
||||
{
|
||||
"key1": None,
|
||||
}
|
||||
)
|
||||
assert isinstance(var, ObjectSegment)
|
||||
assert var.value["key1"] is None
|
||||
|
||||
|
||||
def test_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "object",
|
||||
"name": "test_object",
|
||||
"description": "Description of the variable.",
|
||||
"value": {
|
||||
"key1": "text",
|
||||
"key2": 2,
|
||||
},
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value["key1"], str)
|
||||
assert isinstance(variable.value["key2"], int)
|
||||
|
||||
|
||||
def test_array_string_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[string]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
"text",
|
||||
"text",
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
|
||||
|
||||
def test_array_number_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[number]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
1,
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[object]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
{
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
{
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
assert isinstance(variable.value[0]["key1"], str)
|
||||
assert isinstance(variable.value[0]["key2"], int)
|
||||
assert isinstance(variable.value[1]["key1"], str)
|
||||
assert isinstance(variable.value[1]["key2"], int)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_200_kb():
|
||||
with pytest.raises(VariableError):
|
||||
variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"value_type": "string",
|
||||
"name": "test_text",
|
||||
"value": "a" * 1024 * 201,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_array_none_variable():
|
||||
var = variable_factory.build_segment([None, None, None, None])
|
||||
assert isinstance(var, ArrayAnySegment)
|
||||
assert var.value == [None, None, None, None]
|
||||
|
||||
|
||||
def test_build_segment_none_type():
|
||||
"""Test building NoneSegment from None value."""
|
||||
segment = variable_factory.build_segment(None)
|
||||
assert isinstance(segment, NoneSegment)
|
||||
assert segment.value is None
|
||||
assert segment.value_type == SegmentType.NONE
|
||||
|
||||
|
||||
def test_build_segment_none_type_properties():
|
||||
"""Test NoneSegment properties and methods."""
|
||||
segment = variable_factory.build_segment(None)
|
||||
assert segment.text == ""
|
||||
assert segment.log == ""
|
||||
assert segment.markdown == ""
|
||||
assert segment.to_object() is None
|
||||
|
||||
|
||||
def test_build_segment_array_file_single_file():
|
||||
"""Test building ArrayFileSegment from list with single file."""
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file.png",
|
||||
filename="test-file",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
segment = variable_factory.build_segment([file])
|
||||
assert isinstance(segment, ArrayFileSegment)
|
||||
assert len(segment.value) == 1
|
||||
assert segment.value[0] == file
|
||||
assert segment.value_type == SegmentType.ARRAY_FILE
|
||||
|
||||
|
||||
def test_build_segment_array_file_multiple_files():
|
||||
"""Test building ArrayFileSegment from list with multiple files."""
|
||||
file1 = File(
|
||||
id="test_file_id_1",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file1.png",
|
||||
filename="test-file1",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
file2 = File(
|
||||
id="test_file_id_2",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_relation_id",
|
||||
filename="test-file2",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=500,
|
||||
)
|
||||
segment = variable_factory.build_segment([file1, file2])
|
||||
assert isinstance(segment, ArrayFileSegment)
|
||||
assert len(segment.value) == 2
|
||||
assert segment.value[0] == file1
|
||||
assert segment.value[1] == file2
|
||||
assert segment.value_type == SegmentType.ARRAY_FILE
|
||||
|
||||
|
||||
def test_build_segment_array_file_empty_list():
|
||||
"""Test building ArrayFileSegment from empty list should create ArrayAnySegment."""
|
||||
segment = variable_factory.build_segment([])
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == []
|
||||
assert segment.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
|
||||
def test_build_segment_array_any_mixed_types():
|
||||
"""Test building ArrayAnySegment from list with mixed types."""
|
||||
mixed_values = ["string", 42, 3.14, {"key": "value"}, None]
|
||||
segment = variable_factory.build_segment(mixed_values)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == mixed_values
|
||||
assert segment.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
|
||||
def test_build_segment_array_any_with_nested_arrays():
|
||||
"""Test building ArrayAnySegment from list containing arrays."""
|
||||
nested_values = [["nested", "array"], [1, 2, 3], "string"]
|
||||
segment = variable_factory.build_segment(nested_values)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == nested_values
|
||||
assert segment.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
|
||||
def test_build_segment_array_any_mixed_with_files():
|
||||
"""Test building ArrayAnySegment from list with files and other types."""
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file.png",
|
||||
filename="test-file",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
mixed_values = [file, "string", 42]
|
||||
segment = variable_factory.build_segment(mixed_values)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == mixed_values
|
||||
assert segment.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
|
||||
def test_build_segment_array_any_all_none_values():
|
||||
"""Test building ArrayAnySegment from list with all None values."""
|
||||
none_values = [None, None, None]
|
||||
segment = variable_factory.build_segment(none_values)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == none_values
|
||||
assert segment.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
|
||||
def test_build_segment_array_file_properties():
|
||||
"""Test ArrayFileSegment properties and methods."""
|
||||
file1 = File(
|
||||
id="test_file_id_1",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file1.png",
|
||||
filename="test-file1",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
file2 = File(
|
||||
id="test_file_id_2",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file2.txt",
|
||||
filename="test-file2",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=500,
|
||||
)
|
||||
segment = variable_factory.build_segment([file1, file2])
|
||||
|
||||
# Test properties
|
||||
assert segment.text == "" # ArrayFileSegment text property returns empty string
|
||||
assert segment.log == "" # ArrayFileSegment log property returns empty string
|
||||
assert segment.markdown == f"{file1.markdown}\n{file2.markdown}"
|
||||
assert segment.to_object() == [file1, file2]
|
||||
|
||||
|
||||
def test_build_segment_array_any_properties():
|
||||
"""Test ArrayAnySegment properties and methods."""
|
||||
mixed_values = ["string", 42, None]
|
||||
segment = variable_factory.build_segment(mixed_values)
|
||||
|
||||
# Test properties
|
||||
assert segment.text == str(mixed_values)
|
||||
assert segment.log == str(mixed_values)
|
||||
assert segment.markdown == "string\n42\nNone"
|
||||
assert segment.to_object() == mixed_values
|
||||
|
||||
|
||||
def test_build_segment_edge_cases():
|
||||
"""Test edge cases for build_segment function."""
|
||||
# Test with complex nested structures
|
||||
complex_structure = [{"nested": {"deep": [1, 2, 3]}}, [{"inner": "value"}], "mixed"]
|
||||
segment = variable_factory.build_segment(complex_structure)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == complex_structure
|
||||
|
||||
# Test with single None in list
|
||||
single_none = [None]
|
||||
segment = variable_factory.build_segment(single_none)
|
||||
assert isinstance(segment, ArrayAnySegment)
|
||||
assert segment.value == single_none
|
||||
|
||||
|
||||
def test_build_segment_file_array_with_different_file_types():
|
||||
"""Test ArrayFileSegment with different file types."""
|
||||
image_file = File(
|
||||
id="image_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/image.png",
|
||||
filename="image",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
|
||||
video_file = File(
|
||||
id="video_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.VIDEO,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="video_relation_id",
|
||||
filename="video",
|
||||
extension=".mp4",
|
||||
mime_type="video/mp4",
|
||||
size=5000,
|
||||
)
|
||||
|
||||
audio_file = File(
|
||||
id="audio_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.AUDIO,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="audio_relation_id",
|
||||
filename="audio",
|
||||
extension=".mp3",
|
||||
mime_type="audio/mpeg",
|
||||
size=3000,
|
||||
)
|
||||
|
||||
segment = variable_factory.build_segment([image_file, video_file, audio_file])
|
||||
assert isinstance(segment, ArrayFileSegment)
|
||||
assert len(segment.value) == 3
|
||||
assert segment.value[0].type == FileType.IMAGE
|
||||
assert segment.value[1].type == FileType.VIDEO
|
||||
assert segment.value[2].type == FileType.AUDIO
|
||||
|
||||
|
||||
@st.composite
|
||||
def _generate_file(draw) -> File:
|
||||
file_id = draw(st.text(min_size=1, max_size=10))
|
||||
tenant_id = draw(st.text(min_size=1, max_size=10))
|
||||
file_type, mime_type, extension = draw(
|
||||
st.sampled_from(
|
||||
[
|
||||
(FileType.IMAGE, "image/png", ".png"),
|
||||
(FileType.VIDEO, "video/mp4", ".mp4"),
|
||||
(FileType.DOCUMENT, "text/plain", ".txt"),
|
||||
(FileType.AUDIO, "audio/mpeg", ".mp3"),
|
||||
]
|
||||
)
|
||||
)
|
||||
filename = "test-file"
|
||||
size = draw(st.integers(min_value=0, max_value=1024 * 1024))
|
||||
|
||||
transfer_method = draw(st.sampled_from(list(FileTransferMethod)))
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = "https://test.example.com/test-file"
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
related_id=None,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
)
|
||||
else:
|
||||
relation_id = draw(st.uuids(version=4))
|
||||
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
related_id=str(relation_id),
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
|
||||
return st.one_of(
|
||||
st.none(),
|
||||
st.integers(),
|
||||
st.floats(),
|
||||
st.text(),
|
||||
_generate_file(),
|
||||
)
|
||||
|
||||
|
||||
@given(_scalar_value())
|
||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
seg = variable_factory.build_segment(value)
|
||||
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
|
||||
if isinstance(value, float) and math.isnan(value):
|
||||
assert math.isnan(seg.value)
|
||||
else:
|
||||
assert seg.value == value
|
||||
|
||||
|
||||
@given(st.lists(_scalar_value()))
|
||||
def test_build_segment_and_extract_values_for_array_types(values):
|
||||
seg = variable_factory.build_segment(values)
|
||||
assert seg.value == values
|
||||
|
||||
|
||||
def test_build_segment_type_for_scalar():
|
||||
@dataclass(frozen=True)
|
||||
class TestCase:
|
||||
value: int | float | str | File
|
||||
expected_type: SegmentType
|
||||
|
||||
file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file.png",
|
||||
filename="test-file",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
)
|
||||
cases = [
|
||||
TestCase(0, SegmentType.NUMBER),
|
||||
TestCase(0.0, SegmentType.NUMBER),
|
||||
TestCase("", SegmentType.STRING),
|
||||
TestCase(file, SegmentType.FILE),
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
segment = variable_factory.build_segment(c.value)
|
||||
assert segment.value_type == c.expected_type, f"test case {idx} failed."
|
||||
|
||||
|
||||
class TestBuildSegmentWithType:
|
||||
"""Test cases for build_segment_with_type function."""
|
||||
|
||||
def test_string_type(self):
|
||||
"""Test building a string segment with correct type."""
|
||||
result = build_segment_with_type(SegmentType.STRING, "hello")
|
||||
assert isinstance(result, StringSegment)
|
||||
assert result.value == "hello"
|
||||
assert result.value_type == SegmentType.STRING
|
||||
|
||||
def test_number_type_integer(self):
|
||||
"""Test building a number segment with integer value."""
|
||||
result = build_segment_with_type(SegmentType.NUMBER, 42)
|
||||
assert isinstance(result, IntegerSegment)
|
||||
assert result.value == 42
|
||||
assert result.value_type == SegmentType.NUMBER
|
||||
|
||||
def test_number_type_float(self):
|
||||
"""Test building a number segment with float value."""
|
||||
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
||||
assert isinstance(result, FloatSegment)
|
||||
assert result.value == 3.14
|
||||
assert result.value_type == SegmentType.NUMBER
|
||||
|
||||
def test_object_type(self):
|
||||
"""Test building an object segment with correct type."""
|
||||
test_obj = {"key": "value", "nested": {"inner": 123}}
|
||||
result = build_segment_with_type(SegmentType.OBJECT, test_obj)
|
||||
assert isinstance(result, ObjectSegment)
|
||||
assert result.value == test_obj
|
||||
assert result.value_type == SegmentType.OBJECT
|
||||
|
||||
def test_file_type(self):
|
||||
"""Test building a file segment with correct type."""
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
tenant_id="test_tenant_id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://test.example.com/test-file.png",
|
||||
filename="test-file",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=1000,
|
||||
storage_key="test_storage_key",
|
||||
)
|
||||
result = build_segment_with_type(SegmentType.FILE, test_file)
|
||||
assert isinstance(result, FileSegment)
|
||||
assert result.value == test_file
|
||||
assert result.value_type == SegmentType.FILE
|
||||
|
||||
def test_none_type(self):
|
||||
"""Test building a none segment with None value."""
|
||||
result = build_segment_with_type(SegmentType.NONE, None)
|
||||
assert isinstance(result, NoneSegment)
|
||||
assert result.value is None
|
||||
assert result.value_type == SegmentType.NONE
|
||||
|
||||
def test_empty_array_string(self):
|
||||
"""Test building an empty array[string] segment."""
|
||||
result = build_segment_with_type(SegmentType.ARRAY_STRING, [])
|
||||
assert isinstance(result, ArrayStringSegment)
|
||||
assert result.value == []
|
||||
assert result.value_type == SegmentType.ARRAY_STRING
|
||||
|
||||
def test_empty_array_number(self):
|
||||
"""Test building an empty array[number] segment."""
|
||||
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [])
|
||||
assert isinstance(result, ArrayNumberSegment)
|
||||
assert result.value == []
|
||||
assert result.value_type == SegmentType.ARRAY_NUMBER
|
||||
|
||||
def test_empty_array_object(self):
|
||||
"""Test building an empty array[object] segment."""
|
||||
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [])
|
||||
assert isinstance(result, ArrayObjectSegment)
|
||||
assert result.value == []
|
||||
assert result.value_type == SegmentType.ARRAY_OBJECT
|
||||
|
||||
def test_empty_array_file(self):
|
||||
"""Test building an empty array[file] segment."""
|
||||
result = build_segment_with_type(SegmentType.ARRAY_FILE, [])
|
||||
assert isinstance(result, ArrayFileSegment)
|
||||
assert result.value == []
|
||||
assert result.value_type == SegmentType.ARRAY_FILE
|
||||
|
||||
def test_empty_array_any(self):
|
||||
"""Test building an empty array[any] segment."""
|
||||
result = build_segment_with_type(SegmentType.ARRAY_ANY, [])
|
||||
assert isinstance(result, ArrayAnySegment)
|
||||
assert result.value == []
|
||||
assert result.value_type == SegmentType.ARRAY_ANY
|
||||
|
||||
def test_array_with_values(self):
|
||||
"""Test building array segments with actual values."""
|
||||
# Array of strings
|
||||
result = build_segment_with_type(SegmentType.ARRAY_STRING, ["hello", "world"])
|
||||
assert isinstance(result, ArrayStringSegment)
|
||||
assert result.value == ["hello", "world"]
|
||||
assert result.value_type == SegmentType.ARRAY_STRING
|
||||
|
||||
# Array of numbers
|
||||
result = build_segment_with_type(SegmentType.ARRAY_NUMBER, [1, 2, 3.14])
|
||||
assert isinstance(result, ArrayNumberSegment)
|
||||
assert result.value == [1, 2, 3.14]
|
||||
assert result.value_type == SegmentType.ARRAY_NUMBER
|
||||
|
||||
# Array of objects
|
||||
result = build_segment_with_type(SegmentType.ARRAY_OBJECT, [{"a": 1}, {"b": 2}])
|
||||
assert isinstance(result, ArrayObjectSegment)
|
||||
assert result.value == [{"a": 1}, {"b": 2}]
|
||||
assert result.value_type == SegmentType.ARRAY_OBJECT
|
||||
|
||||
def test_type_mismatch_string_to_number(self):
|
||||
"""Test type mismatch when expecting number but getting string."""
|
||||
with pytest.raises(TypeMismatchError) as exc_info:
|
||||
build_segment_with_type(SegmentType.NUMBER, "not_a_number")
|
||||
|
||||
assert "Type mismatch" in str(exc_info.value)
|
||||
assert "expected number" in str(exc_info.value)
|
||||
assert "str" in str(exc_info.value)
|
||||
|
||||
def test_type_mismatch_number_to_string(self):
|
||||
"""Test type mismatch when expecting string but getting number."""
|
||||
with pytest.raises(TypeMismatchError) as exc_info:
|
||||
build_segment_with_type(SegmentType.STRING, 123)
|
||||
|
||||
assert "Type mismatch" in str(exc_info.value)
|
||||
assert "expected string" in str(exc_info.value)
|
||||
assert "int" in str(exc_info.value)
|
||||
|
||||
def test_type_mismatch_none_to_string(self):
|
||||
"""Test type mismatch when expecting string but getting None."""
|
||||
with pytest.raises(TypeMismatchError) as exc_info:
|
||||
build_segment_with_type(SegmentType.STRING, None)
|
||||
|
||||
assert "Expected string, but got None" in str(exc_info.value)
|
||||
|
||||
def test_type_mismatch_empty_list_to_non_array(self):
|
||||
"""Test type mismatch when expecting non-array type but getting empty list."""
|
||||
with pytest.raises(TypeMismatchError) as exc_info:
|
||||
build_segment_with_type(SegmentType.STRING, [])
|
||||
|
||||
assert "Expected string, but got empty list" in str(exc_info.value)
|
||||
|
||||
def test_type_mismatch_object_to_array(self):
|
||||
"""Test type mismatch when expecting array but getting object."""
|
||||
with pytest.raises(TypeMismatchError) as exc_info:
|
||||
build_segment_with_type(SegmentType.ARRAY_STRING, {"key": "value"})
|
||||
|
||||
assert "Type mismatch" in str(exc_info.value)
|
||||
assert "expected array[string]" in str(exc_info.value)
|
||||
|
||||
def test_compatible_number_types(self):
|
||||
"""Test that int and float are both compatible with NUMBER type."""
|
||||
# Integer should work
|
||||
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
|
||||
assert isinstance(result_int, IntegerSegment)
|
||||
assert result_int.value_type == SegmentType.NUMBER
|
||||
|
||||
# Float should work
|
||||
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
|
||||
assert isinstance(result_float, FloatSegment)
|
||||
assert result_float.value_type == SegmentType.NUMBER
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("segment_type", "value", "expected_class"),
|
||||
[
|
||||
(SegmentType.STRING, "test", StringSegment),
|
||||
(SegmentType.NUMBER, 42, IntegerSegment),
|
||||
(SegmentType.NUMBER, 3.14, FloatSegment),
|
||||
(SegmentType.OBJECT, {}, ObjectSegment),
|
||||
(SegmentType.NONE, None, NoneSegment),
|
||||
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
|
||||
(SegmentType.ARRAY_NUMBER, [], ArrayNumberSegment),
|
||||
(SegmentType.ARRAY_OBJECT, [], ArrayObjectSegment),
|
||||
(SegmentType.ARRAY_ANY, [], ArrayAnySegment),
|
||||
],
|
||||
)
|
||||
def test_parametrized_valid_types(self, segment_type, value, expected_class):
|
||||
"""Parametrized test for valid type combinations."""
|
||||
result = build_segment_with_type(segment_type, value)
|
||||
assert isinstance(result, expected_class)
|
||||
assert result.value == value
|
||||
assert result.value_type == segment_type
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("segment_type", "value"),
|
||||
[
|
||||
(SegmentType.STRING, 123),
|
||||
(SegmentType.NUMBER, "not_a_number"),
|
||||
(SegmentType.OBJECT, "not_an_object"),
|
||||
(SegmentType.ARRAY_STRING, "not_an_array"),
|
||||
(SegmentType.STRING, None),
|
||||
(SegmentType.NUMBER, None),
|
||||
],
|
||||
)
|
||||
def test_parametrized_type_mismatches(self, segment_type, value):
|
||||
"""Parametrized test for type mismatches that should raise TypeMismatchError."""
|
||||
with pytest.raises(TypeMismatchError):
|
||||
build_segment_with_type(segment_type, value)
|
||||
|
||||
|
||||
# Test cases for ValueError scenarios in build_segment function
|
||||
class TestBuildSegmentValueErrors:
|
||||
"""Test cases for ValueError scenarios in the build_segment function."""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValueErrorTestCase:
|
||||
"""Test case data for ValueError scenarios."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
test_value: Any
|
||||
|
||||
def _get_test_cases(self):
|
||||
"""Get all test cases for ValueError scenarios."""
|
||||
|
||||
# Define inline classes for complex test cases
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
def unsupported_function():
|
||||
return "test"
|
||||
|
||||
def gen():
|
||||
yield 1
|
||||
yield 2
|
||||
|
||||
return [
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_custom_type",
|
||||
description="custom class that doesn't match any supported type",
|
||||
test_value=CustomType(),
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_set_type",
|
||||
description="set (unsupported collection type)",
|
||||
test_value={1, 2, 3},
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_tuple_type", description="tuple (unsupported type)", test_value=(1, 2, 3)
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_bytes_type",
|
||||
description="bytes (unsupported type)",
|
||||
test_value=b"hello world",
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_function_type",
|
||||
description="function (unsupported type)",
|
||||
test_value=unsupported_function,
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="unsupported_module_type", description="module (unsupported type)", test_value=math
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="array_with_unsupported_element_types",
|
||||
description="array with unsupported element types",
|
||||
test_value=[CustomType()],
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="mixed_array_with_unsupported_types",
|
||||
description="array with mix of supported and unsupported types",
|
||||
test_value=["valid_string", 42, CustomType()],
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="nested_unsupported_types",
|
||||
description="nested structures containing unsupported types",
|
||||
test_value=[{"valid": "data"}, CustomType()],
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="complex_number_type",
|
||||
description="complex number (unsupported type)",
|
||||
test_value=3 + 4j,
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="range_type", description="range object (unsupported type)", test_value=range(10)
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="generator_type",
|
||||
description="generator (unsupported type)",
|
||||
test_value=gen(),
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="exception_message_contains_value",
|
||||
description="set to verify error message contains the actual unsupported value",
|
||||
test_value={1, 2, 3},
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="array_with_mixed_unsupported_segment_types",
|
||||
description="array processing with unsupported segment types in match",
|
||||
test_value=[CustomType()],
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="frozenset_type",
|
||||
description="frozenset (unsupported type)",
|
||||
test_value=frozenset([1, 2, 3]),
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="memoryview_type",
|
||||
description="memoryview (unsupported type)",
|
||||
test_value=memoryview(b"hello"),
|
||||
),
|
||||
self.ValueErrorTestCase(
|
||||
name="slice_type", description="slice object (unsupported type)", test_value=slice(1, 10, 2)
|
||||
),
|
||||
self.ValueErrorTestCase(name="type_object", description="type object (unsupported type)", test_value=type),
|
||||
self.ValueErrorTestCase(
|
||||
name="generic_object", description="generic object (unsupported type)", test_value=object()
|
||||
),
|
||||
]
|
||||
|
||||
def test_build_segment_unsupported_types(self):
|
||||
"""Table-driven test for all ValueError scenarios in build_segment function."""
|
||||
test_cases = self._get_test_cases()
|
||||
|
||||
for index, test_case in enumerate(test_cases, 1):
|
||||
# Use test value directly
|
||||
test_value = test_case.test_value
|
||||
|
||||
with pytest.raises(ValueError) as exc_info: # noqa: PT012
|
||||
segment = variable_factory.build_segment(test_value)
|
||||
pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but not, result={segment}")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "not supported value" in error_message, (
|
||||
f"Test case {index} ({test_case.name}): Expected 'not supported value' in error message, "
|
||||
f"but got: {error_message}"
|
||||
)
|
||||
|
||||
def test_build_segment_boolean_type_note(self):
|
||||
"""Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError."""
|
||||
# Boolean values in Python are subclasses of int, so they get processed as integers
|
||||
# True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0)
|
||||
true_segment = variable_factory.build_segment(True)
|
||||
false_segment = variable_factory.build_segment(False)
|
||||
|
||||
# Verify they are processed as integers, not as errors
|
||||
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
|
||||
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
|
||||
assert true_segment.value_type == SegmentType.NUMBER
|
||||
assert false_segment.value_type == SegmentType.NUMBER
|
20
api/tests/unit_tests/libs/test_datetime_utils.py
Normal file
20
api/tests/unit_tests/libs/test_datetime_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import datetime
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def test_naive_utc_now(monkeypatch):
|
||||
tz_aware_utc_now = datetime.datetime.now(tz=datetime.UTC)
|
||||
|
||||
def _now_func(tz: datetime.timezone | None) -> datetime.datetime:
|
||||
return tz_aware_utc_now.astimezone(tz)
|
||||
|
||||
monkeypatch.setattr("libs.datetime_utils._now_func", _now_func)
|
||||
|
||||
naive_datetime = naive_utc_now()
|
||||
|
||||
assert naive_datetime.tzinfo is None
|
||||
assert naive_datetime.date() == tz_aware_utc_now.date()
|
||||
naive_time = naive_datetime.time()
|
||||
utc_time = tz_aware_utc_now.time()
|
||||
assert naive_time == utc_time
|
0
api/tests/unit_tests/models/__init__.py
Normal file
0
api/tests/unit_tests/models/__init__.py
Normal file
@@ -1,10 +1,15 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
from core.variables.segments import IntegerSegment, Segment
|
||||
from factories.variable_factory import build_segment
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
|
||||
|
||||
def test_environment_variables():
|
||||
@@ -163,3 +168,147 @@ class TestWorkflowNodeExecution:
|
||||
original = {"a": 1, "b": ["2"]}
|
||||
node_exec.execution_metadata = json.dumps(original)
|
||||
assert node_exec.execution_metadata_dict == original
|
||||
|
||||
|
||||
class TestIsSystemVariableEditable:
|
||||
def test_is_system_variable(self):
|
||||
cases = [
|
||||
("query", True),
|
||||
("files", True),
|
||||
("dialogue_count", False),
|
||||
("conversation_id", False),
|
||||
("user_id", False),
|
||||
("app_id", False),
|
||||
("workflow_id", False),
|
||||
("workflow_run_id", False),
|
||||
]
|
||||
for name, editable in cases:
|
||||
assert editable == is_system_variable_editable(name)
|
||||
|
||||
assert is_system_variable_editable("invalid_or_new_system_variable") == False
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableGetValue:
|
||||
def test_get_value_by_case(self):
|
||||
@dataclasses.dataclass
|
||||
class TestCase:
|
||||
name: str
|
||||
value: Segment
|
||||
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
test_file = File(
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/example.jpg",
|
||||
filename="example.jpg",
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
size=100,
|
||||
)
|
||||
cases: list[TestCase] = [
|
||||
TestCase(
|
||||
name="number/int",
|
||||
value=build_segment(1),
|
||||
),
|
||||
TestCase(
|
||||
name="number/float",
|
||||
value=build_segment(1.0),
|
||||
),
|
||||
TestCase(
|
||||
name="string",
|
||||
value=build_segment("a"),
|
||||
),
|
||||
TestCase(
|
||||
name="object",
|
||||
value=build_segment({}),
|
||||
),
|
||||
TestCase(
|
||||
name="file",
|
||||
value=build_segment(test_file),
|
||||
),
|
||||
TestCase(
|
||||
name="array[any]",
|
||||
value=build_segment([1, "a"]),
|
||||
),
|
||||
TestCase(
|
||||
name="array[string]",
|
||||
value=build_segment(["a", "b"]),
|
||||
),
|
||||
TestCase(
|
||||
name="array[number]/int",
|
||||
value=build_segment([1, 2]),
|
||||
),
|
||||
TestCase(
|
||||
name="array[number]/float",
|
||||
value=build_segment([1.0, 2.0]),
|
||||
),
|
||||
TestCase(
|
||||
name="array[number]/mixed",
|
||||
value=build_segment([1, 2.0]),
|
||||
),
|
||||
TestCase(
|
||||
name="array[object]",
|
||||
value=build_segment([{}, {"a": 1}]),
|
||||
),
|
||||
TestCase(
|
||||
name="none",
|
||||
value=build_segment(None),
|
||||
),
|
||||
]
|
||||
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"test case {c.name} failed, index={idx}"
|
||||
draft_var = WorkflowDraftVariable()
|
||||
draft_var.set_value(c.value)
|
||||
assert c.value == draft_var.get_value(), fail_msg
|
||||
|
||||
def test_file_variable_preserves_all_fields(self):
|
||||
"""Test that File type variables preserve all fields during encoding/decoding."""
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Create a File with specific field values
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/test.jpg",
|
||||
filename="test.jpg",
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
size=12345, # Specific size to test preservation
|
||||
storage_key="test_storage_key",
|
||||
)
|
||||
|
||||
# Create a FileSegment and WorkflowDraftVariable
|
||||
file_segment = build_segment(test_file)
|
||||
draft_var = WorkflowDraftVariable()
|
||||
draft_var.set_value(file_segment)
|
||||
|
||||
# Retrieve the value and verify all fields are preserved
|
||||
retrieved_segment = draft_var.get_value()
|
||||
retrieved_file = retrieved_segment.value
|
||||
|
||||
# Verify all important fields are preserved
|
||||
assert retrieved_file.id == test_file.id
|
||||
assert retrieved_file.tenant_id == test_file.tenant_id
|
||||
assert retrieved_file.type == test_file.type
|
||||
assert retrieved_file.transfer_method == test_file.transfer_method
|
||||
assert retrieved_file.remote_url == test_file.remote_url
|
||||
assert retrieved_file.filename == test_file.filename
|
||||
assert retrieved_file.extension == test_file.extension
|
||||
assert retrieved_file.mime_type == test_file.mime_type
|
||||
assert retrieved_file.size == test_file.size # This was the main issue being fixed
|
||||
# Note: storage_key is not serialized in model_dump() so it won't be preserved
|
||||
|
||||
# Verify the segments have the same type and the important fields match
|
||||
assert file_segment.value_type == retrieved_segment.value_type
|
||||
|
||||
def test_get_and_set_value(self):
|
||||
draft_var = WorkflowDraftVariable()
|
||||
int_var = IntegerSegment(value=1)
|
||||
draft_var.set_value(int_var)
|
||||
value = draft_var.get_value()
|
||||
assert value == int_var
|
||||
|
@@ -0,0 +1,222 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
class TestDraftVariableSaver:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = mock.MagicMock(spec=Session)
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
|
||||
|
||||
def test__normalize_variable_for_start_node(self):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCase:
|
||||
name: str
|
||||
input_node_id: str
|
||||
input_name: str
|
||||
expected_node_id: str
|
||||
expected_name: str
|
||||
|
||||
_NODE_ID = "1747228642872"
|
||||
cases = [
|
||||
TestCase(
|
||||
name="name with `sys.` prefix should return the system node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="sys.workflow_id",
|
||||
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
expected_name="workflow_id",
|
||||
),
|
||||
TestCase(
|
||||
name="name without `sys.` prefix should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="start_input",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="start_input",
|
||||
),
|
||||
TestCase(
|
||||
name="dummy_variable should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="__dummy__",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="__dummy__",
|
||||
),
|
||||
]
|
||||
|
||||
mock_session = mock.MagicMock(spec=Session)
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_execution_id="test_execution_id",
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def test_reset_conversation_variable(self):
|
||||
"""Test resetting a conversation variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.CONVERSATION
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
|
||||
# Mock the _reset_conv_var method
|
||||
expected_result = Mock(spec=WorkflowDraftVariable)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
result = service.reset_variable(mock_workflow, mock_variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(mock_workflow, mock_variable)
|
||||
assert result == expected_result
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with no execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = None
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_missing_execution_record(self):
|
||||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
|
||||
# Mock session.scalars to return None (no execution record found)
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = None
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=mock_variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(self):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
# Create mock variable with execution ID
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.NODE
|
||||
mock_variable.node_execution_id = "exec-id"
|
||||
mock_variable.id = "var-id"
|
||||
mock_variable.name = "test_var"
|
||||
mock_variable.node_id = "node-id"
|
||||
mock_variable.value_type = SegmentType.STRING
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.process_data_dict = {"test_var": "process_value"}
|
||||
mock_execution.outputs_dict = {"test_var": "output_value"}
|
||||
|
||||
# Mock session.scalars to return the execution record
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.first.return_value = mock_execution
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
mock_workflow.get_node_config_by_id.return_value = mock_node_config
|
||||
mock_workflow.get_node_type_from_node_config.return_value = NodeType.LLM
|
||||
|
||||
result = service._reset_node_var(mock_workflow, mock_variable)
|
||||
|
||||
# Verify variable.set_value was called with the correct value
|
||||
mock_variable.set_value.assert_called_once()
|
||||
# Verify last_edited_at was reset
|
||||
assert mock_variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
# Should return the updated variable
|
||||
assert result == mock_variable
|
||||
|
||||
def test_reset_system_variable_raises_error(self):
|
||||
"""Test that resetting a system variable raises an error"""
|
||||
mock_session = Mock(spec=Session)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
mock_workflow = Mock(spec=Workflow)
|
||||
mock_workflow.app_id = self._get_test_app_id()
|
||||
|
||||
mock_variable = Mock(spec=WorkflowDraftVariable)
|
||||
mock_variable.get_variable_type.return_value = DraftVariableType.SYS # Not a valid enum value for this test
|
||||
mock_variable.id = "var-id"
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(mock_workflow, mock_variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert "variable_id=var-id" in str(exc_info.value)
|
Reference in New Issue
Block a user