From 2c1ab4879f1cdfcbcb8a96d3f11d6bf5e442e8e1 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 16 Jul 2025 12:31:37 +0800 Subject: [PATCH] refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024. --- .../console/app/workflow_draft_variable.py | 11 +- api/core/app/apps/advanced_chat/app_runner.py | 29 +- .../advanced_chat/generate_task_pipeline.py | 22 +- api/core/app/apps/workflow/app_runner.py | 17 +- .../apps/workflow/generate_task_pipeline.py | 16 +- api/core/app/apps/workflow_app_runner.py | 5 +- api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/variables/segments.py | 56 ++- api/core/variables/types.py | 156 ++++++- api/core/variables/variables.py | 34 +- api/core/workflow/entities/variable_pool.py | 63 ++- .../entities/graph_runtime_state.py | 6 +- api/core/workflow/nodes/loop/entities.py | 24 +- api/core/workflow/nodes/loop/loop_node.py | 34 +- api/core/workflow/nodes/start/start_node.py | 2 +- .../nodes/variable_assigner/v1/node.py | 5 + .../nodes/variable_assigner/v2/constants.py | 1 + .../nodes/variable_assigner/v2/helpers.py | 14 +- api/core/workflow/system_variable.py | 89 ++++ api/core/workflow/workflow_cycle_manager.py | 18 +- api/core/workflow/workflow_entry.py 
| 3 +- api/factories/variable_factory.py | 67 ++- api/fields/_value_type_serializer.py | 15 + api/fields/conversation_variable_fields.py | 4 +- api/fields/workflow_fields.py | 13 +- api/models/workflow.py | 13 +- api/services/workflow_service.py | 43 +- .../workflow/nodes/test_code.py | 4 +- .../workflow/nodes/test_http.py | 4 +- .../workflow/nodes/test_llm.py | 16 +- .../nodes/test_parameter_extractor.py | 11 +- .../workflow/nodes/test_template_transform.py | 4 +- .../workflow/nodes/test_tool.py | 4 +- .../unit_tests/core/variables/test_segment.py | 345 ++++++++++++++- .../core/variables/test_segment_type.py | 60 +++ .../core/variables/test_variables.py | 3 +- .../entities/test_graph_runtime_state.py | 146 +++++++ .../entities/test_node_run_state.py | 401 ++++++++++++++++++ .../graph_engine/test_graph_engine.py | 47 +- .../core/workflow/nodes/answer/test_answer.py | 4 +- .../answer/test_answer_stream_processor.py | 14 +- .../test_http_request_executor.py | 23 +- .../http_request/test_http_request_node.py | 7 +- .../nodes/iteration/test_iteration.py | 50 +-- .../core/workflow/nodes/llm/test_node.py | 13 +- .../core/workflow/nodes/test_answer.py | 4 +- .../workflow/nodes/test_continue_on_error.py | 14 +- .../core/workflow/nodes/test_if_else.py | 8 +- .../workflow/nodes/tool/test_tool_node.py | 3 +- .../v1/test_variable_assigner_v1.py | 8 +- .../v2/test_variable_assigner_v2.py | 10 +- .../core/workflow/test_system_variable.py | 251 +++++++++++ .../core/workflow/test_variable_pool.py | 373 +++++++++++++++- .../workflow/test_workflow_cycle_manager.py | 18 +- .../workflow/utils/test_variable_utils.py | 12 +- .../factories/test_variable_factory.py | 24 +- .../components/base/chat/chat/question.tsx | 4 +- .../nodes/agent/components/tool-icon.tsx | 6 +- 58 files changed, 2325 insertions(+), 328 deletions(-) create mode 100644 api/core/workflow/system_variable.py create mode 100644 api/fields/_value_type_serializer.py create mode 100644 
api/tests/unit_tests/core/variables/test_segment_type.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py create mode 100644 api/tests/unit_tests/core/workflow/test_system_variable.py diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cb..ba93f8275 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -68,13 +68,18 @@ def _create_pagination_parser(): return parser +def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: + value_type = workflow_draft_var.value_type + return value_type.exposed_type().value + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.value, + "value_type": v.value_type.exposed_type().value, "value": v.value, # Do not track edited for env 
vars. "edited": False, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 840a3c9d3..af15324f4 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,9 +16,10 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.moderation.base import ModerationError +from core.variables.variables import VariableUnion from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not workflow: raise ValueError("Workflow not initialized") - user_id = None + user_id: str | None = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: @@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() # Create a variable pool. 
- system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: self.conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, - } + system_inputs = SystemVariable( + query=query, + files=files, + conversation_id=self.conversation.id, + user_id=user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
+ conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 4c52fc3e8..1dc9796d5 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db @@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - 
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, - }, + workflow_system_variables=SystemVariable( + query=message.query, + files=application_generate_entity.files, + conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 07aeb57fa..3a66ffa57 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. 
- system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, - } + + system_inputs = SystemVariable( + files=files, + user_id=user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c6b326d8a..7adc03e9c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -54,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account @@ -107,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.FILES: 
application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, - }, + workflow_system_variables=SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17b9ac582..2f4d234ec 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) diff --git a/api/core/prompt/advanced_prompt_transform.py 
b/api/core/prompt/advanced_prompt_transform.py index 25964ae06..0f0fe65f2 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if self.with_variable_tmpl: - vp = VariablePool() + vp = VariablePool.empty() for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e037..13274f4e0 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,9 +1,9 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Any +from typing import Annotated, Any, TypeAlias -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator from core.file import File @@ -11,6 +11,11 @@ from .types import SegmentType class Segment(BaseModel): + """Segment is runtime type used during the execution of workflow. + + Note: this class is abstract, you should use subclasses of this class instead. + """ + model_config = ConfigDict(frozen=True) value_type: SegmentType @@ -73,7 +78,7 @@ class StringSegment(Segment): class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.FLOAT value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. 
@@ -92,7 +97,7 @@ class FloatSegment(Segment): class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.INTEGER value: int @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): @property def text(self) -> str: return "" + + +def get_segment_discriminator(v: Any) -> SegmentType | None: + if isinstance(v, Segment): + return v.value_type + elif isinstance(v, dict): + value_type = v.get("value_type") + if value_type is None: + return None + try: + seg_type = SegmentType(value_type) + except ValueError: + return None + return seg_type + else: + # return None if the discriminator value isn't found + return None + + +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. 
+SegmentUnion: TypeAlias = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d8288..e39237dba 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,8 +1,27 @@ +from collections.abc import Mapping from enum import StrEnum +from typing import Any, Optional + +from core.file.models import File + + +class ArrayValidation(StrEnum): + """Strategy for validating array elements""" + + # Skip element validation (only check array container) + NONE = "none" + + # Validate the first element (if array is non-empty) + FIRST = "first" + + # Validate all elements in the array. + ALL = "all" class SegmentType(StrEnum): NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" STRING = "string" OBJECT = "object" SECRET = "secret" @@ -19,16 +38,141 @@ class SegmentType(StrEnum): GROUP = "group" - def is_array_type(self): + def is_array_type(self) -> bool: return self in _ARRAY_TYPES + @classmethod + def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. 
+ """ + + if isinstance(value, list): + elem_types: set[SegmentType] = set() + for i in value: + segment_type = cls.infer_segment_type(i) + if segment_type is None: + return None + + elem_types.add(segment_type) + + if len(elem_types) != 1: + if elem_types.issubset(_NUMERICAL_TYPES): + return SegmentType.ARRAY_NUMBER + return SegmentType.ARRAY_ANY + elif all(i.is_array_type() for i in elem_types): + return SegmentType.ARRAY_ANY + match elem_types.pop(): + case SegmentType.STRING: + return SegmentType.ARRAY_STRING + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return SegmentType.ARRAY_NUMBER + case SegmentType.OBJECT: + return SegmentType.ARRAY_OBJECT + case SegmentType.FILE: + return SegmentType.ARRAY_FILE + case SegmentType.NONE: + return SegmentType.ARRAY_ANY + case _: + # This should be unreachable. + raise ValueError(f"not supported value {value}") + if value is None: + return SegmentType.NONE + elif isinstance(value, int) and not isinstance(value, bool): + return SegmentType.INTEGER + elif isinstance(value, float): + return SegmentType.FLOAT + elif isinstance(value, str): + return SegmentType.STRING + elif isinstance(value, dict): + return SegmentType.OBJECT + elif isinstance(value, File): + return SegmentType.FILE + elif isinstance(value, str): + return SegmentType.STRING + else: + return None + + def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: + if not isinstance(value, list): + return False + # Skip element validation if array is empty + if len(value) == 0: + return True + if self == SegmentType.ARRAY_ANY: + return True + element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] + + if array_validation == ArrayValidation.NONE: + return True + elif array_validation == ArrayValidation.FIRST: + return element_type.is_valid(value[0]) + else: + return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + + def is_valid(self, value: Any, array_validation: ArrayValidation = 
ArrayValidation.FIRST) -> bool: + """ + Check if a value matches the segment type. + Users of `SegmentType` should call this method, instead of using + `isinstance` manually. + + Args: + value: The value to validate + array_validation: Validation strategy for array types (ignored for non-array types) + + Returns: + True if the value matches the type under the given validation strategy + """ + if self.is_array_type(): + return self._validate_array(value, array_validation) + elif self == SegmentType.NUMBER: + return isinstance(value, (int, float)) + elif self == SegmentType.STRING: + return isinstance(value, str) + elif self == SegmentType.OBJECT: + return isinstance(value, dict) + elif self == SegmentType.SECRET: + return isinstance(value, str) + elif self == SegmentType.FILE: + return isinstance(value, File) + elif self == SegmentType.NONE: + return value is None + else: + raise AssertionError("this statement should be unreachable.") + + def exposed_type(self) -> "SegmentType": + """Returns the type exposed to the frontend. + + The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. + """ + if self in (SegmentType.INTEGER, SegmentType.FLOAT): + return SegmentType.NUMBER + return self + + +_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { + # ARRAY_ANY does not have a corresponding element type.
+ SegmentType.ARRAY_STRING: SegmentType.STRING, + SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, + SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, + SegmentType.ARRAY_FILE: SegmentType.FILE, +} _ARRAY_TYPES = frozenset( - [ + list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) + + [ SegmentType.ARRAY_ANY, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_FILE, + ] +) + + +_NUMERICAL_TYPES = frozenset( + [ + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, ] ) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682..a31ebc848 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import cast +from typing import Annotated, TypeAlias, cast from uuid import uuid4 -from pydantic import Field +from pydantic import Discriminator, Field, Tag from core.helper import encrypter @@ -20,6 +20,7 @@ from .segments import ( ObjectSegment, Segment, StringSegment, + get_segment_discriminator, ) from .types import SegmentType @@ -27,6 +28,10 @@ from .types import SegmentType class Variable(Segment): """ A variable is a segment that has a name. + + It is mainly used to store segments and their selector in VariablePool. + + Note: this class is abstract, you should use subclasses of this class instead. """ id: str = Field( @@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + + +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. 
+# - The union must include all non-abstract subclasses of `Variable`. +VariableUnion: TypeAlias = Annotated[ + ( + Annotated[NoneVariable, Tag(SegmentType.NONE)] + | Annotated[StringVariable, Tag(SegmentType.STRING)] + | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] + | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] + | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] + | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[SecretVariable, Tag(SegmentType.SECRET)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 80dda2632..646a9d340 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Union +from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories import variable_factory VariableValue = Union[str, int, float, dict, list, File] @@ -23,31 +24,31 @@ class 
VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Segment]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. + + # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. user_inputs: Mapping[str, Any] = Field( description="User inputs", default_factory=dict, ) - system_variables: Mapping[SystemVariableKey, Any] = Field( + system_variables: SystemVariable = Field( description="System variables", - default_factory=dict, ) - environment_variables: Sequence[Variable] = Field( + environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", default_factory=list, ) - conversation_variables: Sequence[Variable] = Field( + conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", default_factory=list, ) def model_post_init(self, context: Any, /) -> None: - for key, value in self.system_variables.items(): - self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) + # Create a mapping from field names to SystemVariableKey enum values + self._add_system_variables(self.system_variables) # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) @@ -83,8 +84,22 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = variable + key, 
hash_key = self._selector_to_keys(selector) + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + + @classmethod + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: + return selector[0], hash(tuple(selector[1:])) + + def _has(self, selector: Sequence[str]) -> bool: + key, hash_key = self._selector_to_keys(selector) + if key not in self.variable_dictionary: + return False + if hash_key not in self.variable_dictionary[key]: + return False + return True def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -102,8 +117,8 @@ class VariablePool(BaseModel): if len(selector) < MIN_SELECTORS_LENGTH: return None - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) + key, hash_key = self._selector_to_keys(selector) + value: Segment | None = self.variable_dictionary[key].get(hash_key) if value is None: selector, attr = selector[:-1], selector[-1] @@ -136,8 +151,9 @@ class VariablePool(BaseModel): if len(selector) == 1: self.variable_dictionary[selector[0]] = {} return + key, hash_key = self._selector_to_keys(selector) hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[key].pop(hash_key, None) def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) @@ -154,3 +170,20 @@ class VariablePool(BaseModel): if isinstance(segment, FileSegment): return segment return None + + def _add_system_variables(self, system_variable: SystemVariable): + sys_var_mapping = system_variable.to_dict() + for key, value in sys_var_mapping.items(): + if value is None: + continue + selector = (SYSTEM_VARIABLE_NODE_ID, key) + # If the system variable already exists, do not add it again. + # This ensures that we can keep the id of the system variables intact. 
+ if self._has(selector): + continue + self.add(selector, value) # type: ignore + + @classmethod + def empty(cls) -> "VariablePool": + """Create an empty variable pool.""" + return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index afc09bfac..a62ffe46c 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() """llm usage info""" + + # The `outputs` field stores the final output values generated by executing workflows or chatflows. + # + # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent + # after a serialization and deserialization round trip. outputs: dict[str, Any] = {} - """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3f4a5edab..d04e0bfae 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,11 +1,29 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import AfterValidator, BaseModel, Field +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +_VALID_VAR_TYPE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + ] +) + + +def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: + if seg_type not in _VALID_VAR_TYPE: + raise 
ValueError(...) + return seg_type + class LoopVariableData(BaseModel): """ @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): """ label: str - var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] value: Optional[Any | list[str]] = None diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 11fd7b6c2..20501d031 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import ( - ArrayNumberSegment, - ArrayObjectSegment, - ArrayStringSegment, IntegerSegment, - ObjectSegment, Segment, SegmentType, - StringSegment, ) from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor +from factories.variable_factory import TypeMismatchError, build_segment_with_type if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]): return variable_mapping @staticmethod - def _get_segment_for_constant(var_type: str, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { - "string": (StringSegment, SegmentType.STRING), - "number": (IntegerSegment, SegmentType.NUMBER), - "object": 
(ObjectSegment, SegmentType.OBJECT), - "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), - "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER), - "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), - } if var_type in ["array[string]", "array[number]", "array[object]"]: - if value: + if value and isinstance(value, str): value = json.loads(value) else: value = [] - segment_info = segment_mapping.get(var_type) - if not segment_info: - raise ValueError(f"Invalid variable type: {var_type}") - segment_class, value_type = segment_info - return segment_class(value=value, value_type=value_type) + try: + return build_segment_with_type(var_type, value) + except TypeMismatchError as type_exc: + # Attempt to parse the value as a JSON-encoded string, if applicable. + if not isinstance(value, str): + raise + try: + value = json.loads(value) + except ValueError: + raise type_exc + return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 5ee9bc331..e21559188 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]): def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables + system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling # Set system variables as node outputs. 
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index be5083c9c..1864b1378 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): def get_zero_value(t: SegmentType): + # TODO(QuantumGhost): this should be a method of `SegmentType`. match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: return variable_factory.build_segment([]) @@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment({}) case SegmentType.STRING: return variable_factory.build_segment("") + case SegmentType.INTEGER: + return variable_factory.build_segment(0) + case SegmentType.FLOAT: + return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 3797bfa77..7f760e5ba 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -1,5 +1,6 @@ from core.variables import SegmentType +# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy. 
EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 8fb2a2738..7a20975b1 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): case Operation.OVER_WRITE | Operation.CLEAR: return True case Operation.SET: - return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + return variable_type in { + SegmentType.OBJECT, + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, + } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided - return variable_type == SegmentType.NUMBER + return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} case Operation.APPEND | Operation.EXTEND: # Only array variable can be appended or extended return variable_type in { @@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat match variable_type: case SegmentType.STRING | SegmentType.OBJECT: return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { Operation.OVER_WRITE, Operation.SET, @@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False if operation == Operation.DIVIDE and value == 0: diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py new file mode 
100644 index 000000000..df90c1659 --- /dev/null +++ b/api/core/workflow/system_variable.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from core.file.models import File +from core.workflow.enums import SystemVariableKey + + +class SystemVariable(BaseModel): + """A model for managing system variables. + + Fields with a value of `None` are treated as absent and will not be included + in the variable pool. + """ + + model_config = ConfigDict( + extra="forbid", + serialize_by_alias=True, + validate_by_alias=True, + ) + + user_id: str | None = None + + # Ideally, `app_id` and `workflow_id` should be required and not `None`. + # However, there are scenarios in the codebase where these fields are not set. + # To maintain compatibility, they are marked as optional here. + app_id: str | None = None + workflow_id: str | None = None + + files: Sequence[File] = Field(default_factory=list) + + # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. + # To maintain compatibility with existing workflows, it must be serialized + # as `workflow_run_id` in dictionaries or JSON objects, and also referenced + # as `workflow_run_id` in the variable pool. + workflow_execution_id: str | None = Field( + validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), + serialization_alias="workflow_run_id", + default=None, + ) + # Chatflow related fields. 
+ query: str | None = None + conversation_id: str | None = None + dialogue_count: int | None = None + + @model_validator(mode="before") + @classmethod + def validate_json_fields(cls, data): + if isinstance(data, dict): + # For JSON validation, only allow workflow_run_id + if "workflow_execution_id" in data and "workflow_run_id" not in data: + # This is likely from direct instantiation, allow it + return data + elif "workflow_execution_id" in data and "workflow_run_id" in data: + # Both present, remove workflow_execution_id + data = data.copy() + data.pop("workflow_execution_id") + return data + return data + + @classmethod + def empty(cls) -> "SystemVariable": + return cls() + + def to_dict(self) -> dict[SystemVariableKey, Any]: + # NOTE: This method is provided for compatibility with legacy code. + # New code should use the `SystemVariable` object directly instead of converting + # it to a dictionary, as this conversion results in the loss of type information + # for each key, making static analysis more difficult. 
+ + d: dict[SystemVariableKey, Any] = { + SystemVariableKey.FILES: self.files, + } + if self.user_id is not None: + d[SystemVariableKey.USER_ID] = self.user_id + if self.app_id is not None: + d[SystemVariableKey.APP_ID] = self.app_id + if self.workflow_id is not None: + d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id + if self.workflow_execution_id is not None: + d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id + if self.query is not None: + d[SystemVariableKey.QUERY] = self.query + if self.conversation_id is not None: + d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id + if self.dialogue_count is not None: + d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count + return d diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 0aab2426a..50ff73397 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.enums import SystemVariableKey from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now @@ -43,7 +44,7 @@ class WorkflowCycleManager: self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: dict[SystemVariableKey, Any], + workflow_system_variables: SystemVariable, workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, @@ -56,17 +57,22 @@ class WorkflowCycleManager: def handle_workflow_run_start(self) -> WorkflowExecution: inputs = 
{**self._application_generate_entity.inputs} - for key, value in (self._workflow_system_variables or {}).items(): - if key.value == "conversation": - continue - inputs[f"sys.{key.value}"] = value + + # Iterate over SystemVariable fields using Pydantic's model_fields + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name == SystemVariableKey.CONVERSATION_ID: + continue + inputs[f"sys.{field_name}"] = value # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) + execution_id = str( + self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None + ) or str(uuid4()) execution = WorkflowExecution.new( id_=execution_id, workflow_id=self._workflow_info.workflow_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2868dcb7d..1399efcdb 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom @@ -254,7 +255,7 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=[], ) diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py 
index 250ee4695..39ebd009d 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): + case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): + case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") @@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType: def build_segment(value: Any, /) -> Segment: + # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` + # below if value is None: return NoneSegment() if isinstance(value, str): @@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) + elif len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return 
ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) @@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment: raise ValueError(f"not supported value {value}") +_segment_factory: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.OBJECT: ObjectSegment, + # Array types + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, +} + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: if segment_type == SegmentType.NONE: return NoneSegment() else: - raise TypeMismatchError(f"Expected {segment_type}, but got None") + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: @@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: - raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - - # Build segment using existing logic to infer actual type - inferred_segment = build_segment(value) - inferred_type = inferred_segment.value_type + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, 
type={type(value)}, value={value}" + ) if inferred_type == segment_type: - return inferred_segment - - # Type mismatch - raise error with descriptive message - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but value '{value}' " - f"(type: {type(value).__name__}) corresponds to {inferred_type}" - ) + segment_class = _segment_factory[segment_type] + return segment_class(value_type=segment_type, value=value) + elif segment_type == SegmentType.NUMBER and inferred_type in ( + SegmentType.INTEGER, + SegmentType.FLOAT, + ): + segment_class = _segment_factory[inferred_type] + return segment_class(value_type=inferred_type, value=value) + else: + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( @@ -247,6 +278,6 @@ def segment_to_variable( name=name, description=description, value=segment.value, - selector=selector, + selector=list(selector), ), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py new file mode 100644 index 000000000..8288bd54a --- /dev/null +++ b/api/fields/_value_type_serializer.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from core.variables.segments import Segment +from core.variables.types import SegmentType + + +class _VarTypedDict(TypedDict, total=False): + value_type: SegmentType + + +def serialize_value_type(v: _VarTypedDict | Segment) -> str: + if isinstance(v, Segment): + return v.value_type.exposed_type().value + else: + return v["value_type"].exposed_type().value diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 71785e7d6..c5a0c9a49 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -2,10 +2,12 @@ from flask_restful import fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + conversation_variable_fields = { "id": fields.String, 
"name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.String, "description": fields.String, "created_at": TimestampField, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f00ea71c5..930e59cc1 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) @@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.value, + "value_type": value.value_type.exposed_type().value, "description": value.description, } if isinstance(value, dict): - value_type = value.get("value_type") + value_type_str = value.get("value_type") + if not isinstance(value_type_str, str): + raise TypeError( + f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}" + ) + value_type = SegmentType(value_type_str).exposed_type() if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: raise ValueError(f"Unsupported environment variable value type: {value_type}") return value @@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw): conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.Raw, "description": fields.String, } diff --git a/api/models/workflow.py b/api/models/workflow.py index 77d48bec4..993085920 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,6 +12,7 @@ from sqlalchemy import orm from 
core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils +from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type @@ -347,7 +348,7 @@ class Workflow(Base): ) @property - def environment_variables(self) -> Sequence[Variable]: + def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: self._environment_variables = "{}" @@ -367,11 +368,15 @@ class Workflow(Base): def decrypt_func(var): if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - else: + elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var + else: + raise AssertionError("this statement should be unreachable.") - results = list(map(decrypt_func, results)) - return results + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( + map(decrypt_func, results) + ) + return decrypted_results @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0149d5034..677bc7423 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from sqlalchemy import select @@ -15,10 +15,10 @@ from 
core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable +from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -28,6 +28,7 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -369,7 +370,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -685,36 +686,30 @@ def _setup_variable_pool( ): # Only inject system variables for START node type. if node_type == NodeType.START: - # Create a variable pool. - system_inputs: dict[SystemVariableKey, Any] = { - # From inputs: - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - # From workflow model - SystemVariableKey.APP_ID: workflow.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - # Randomly generated. 
- SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), - } + system_variable = SystemVariable( + user_id=user_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + files=files or [], + workflow_execution_id=str(uuid.uuid4()), + ) # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: - system_inputs.update( - { - SystemVariableKey.QUERY: query, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.DIALOGUE_COUNT: 0, - } - ) + system_variable.query = query + system_variable.conversation_id = conversation_id + system_variable.dialogue_count = 0 else: - system_inputs = {} + system_variable = SystemVariable.empty() # init variable pool variable_pool = VariablePool( - system_variables=system_inputs, + system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
+ conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 13d78c2d8..90bb04f64 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -50,7 +50,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 1ab0cc245..50e726feb 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -6,11 +6,11 @@ import pytest from core.app.entities.app_invoke_entities import 
InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -44,7 +44,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 8acaa54b9..ff119b748 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -13,12 +13,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable 
from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -62,12 +62,14 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + app_id=app_id, + workflow_id=workflow_id, + files=[], + query="what's the weather today?", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 0df8e8b14..dd8466afa 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config @@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict): # construct 
variable pool variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" + ), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index a5f2677a5..1f617fc92 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -6,11 +6,11 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git 
a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 039beedaf..6907e0163 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -44,7 +44,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 1b035d01a..cdc261fd4 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,14 +1,49 @@ +import dataclasses + +from pydantic import BaseModel + +from core.file import File, FileTransferMethod, FileType from core.helper import encrypter -from core.variables import SecretVariable, StringVariable +from core.variables.segments import ( + ArrayAnySegment, + 
ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + SegmentUnion, + StringSegment, + get_segment_discriminator, +) +from core.variables.types import SegmentType +from core.variables.variables import ( + ArrayAnyVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, + VariableUnion, +) from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable def test_segment_group_to_text(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), @@ -30,7 +65,7 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group(): assert segments_group.log == "fake-user-id" assert isinstance(segments_group.value[0], StringVariable) assert segments_group.value[0].value == "fake-user-id" + + +class _Segments(BaseModel): + segments: list[SegmentUnion] + + +class 
_Variables(BaseModel): + variables: list[VariableUnion] + + +def create_test_file( + file_type: FileType = FileType.DOCUMENT, + transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + filename: str = "test.txt", + extension: str = ".txt", + mime_type: str = "text/plain", + size: int = 1024, +) -> File: + """Factory function to create File objects for testing""" + return File( + tenant_id="test-tenant", + type=file_type, + transfer_method=transfer_method, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None, + remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None, + storage_key="test-storage-key", + ) + + +class TestSegmentDumpAndLoad: + """Test suite for segment and variable serialization/deserialization""" + + def test_segments(self): + """Test basic segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_segment_number(self): + """Test number segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_variables(self): + """Test variable serialization compatibility""" + model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) + json = model.model_dump_json() + print("Json: ", json) + restored = _Variables.model_validate_json(json) + assert restored == model + + def test_all_segments_serialization(self): + """Test serialization/deserialization of all segment types""" + # Create one instance of each segment type + test_file = 
create_test_file() + + all_segments: list[SegmentUnion] = [ + NoneSegment(), + StringSegment(value="test string"), + IntegerSegment(value=42), + FloatSegment(value=3.14), + ObjectSegment(value={"key": "value", "number": 123}), + FileSegment(value=test_file), + ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]), + ArrayStringSegment(value=["hello", "world"]), + ArrayNumberSegment(value=[1, 2.5, 3]), + ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]), + ArrayFileSegment(value=[]), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = _Segments(segments=all_segments) + json_str = model.model_dump_json() + loaded = _Segments.model_validate_json(json_str) + + # Verify all segments are preserved + assert len(loaded.segments) == len(all_segments) + + for original, loaded_segment in zip(all_segments, loaded.segments): + assert type(loaded_segment) == type(original) + assert loaded_segment.value_type == original.value_type + + # For file segments, compare key properties instead of exact equality + if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment): + orig_file = original.value + loaded_file = loaded_segment.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_segment.value == original.value + + def test_all_variables_serialization(self): + """Test serialization/deserialization of all variable types""" + # Create one instance of each variable type + test_file = create_test_file() + + all_variables: list[VariableUnion] = [ + NoneVariable(name="none_var"), + StringVariable(value="test string", name="string_var"), + IntegerVariable(value=42, name="int_var"), + FloatVariable(value=3.14, name="float_var"), + ObjectVariable(value={"key": "value", "number": 123}, name="object_var"), + 
FileVariable(value=test_file, name="file_var"), + ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"), + ArrayStringVariable(value=["hello", "world"], name="array_string_var"), + ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"), + ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"), + ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = _Variables(variables=all_variables) + json_str = model.model_dump_json() + loaded = _Variables.model_validate_json(json_str) + + # Verify all variables are preserved + assert len(loaded.variables) == len(all_variables) + + for original, loaded_variable in zip(all_variables, loaded.variables): + assert type(loaded_variable) == type(original) + assert loaded_variable.value_type == original.value_type + assert loaded_variable.name == original.name + + # For file variables, compare key properties instead of exact equality + if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable): + orig_file = original.value + loaded_file = loaded_variable.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_variable.value == original.value + + def test_segment_discriminator_function_for_segment_types(self): + """Test the segment discriminator function""" + + @dataclasses.dataclass + class TestCase: + segment: Segment + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneSegment(), + SegmentType.NONE, + ), + TestCase( + StringSegment(value=""), + SegmentType.STRING, + ), + TestCase( + FloatSegment(value=0.0), + SegmentType.FLOAT, + ), + TestCase( + 
IntegerSegment(value=0), + SegmentType.INTEGER, + ), + TestCase( + ObjectSegment(value={}), + SegmentType.OBJECT, + ), + TestCase( + FileSegment(value=file1), + SegmentType.FILE, + ), + TestCase( + ArrayAnySegment(value=[0, 0.0, ""]), + SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringSegment(value=[""]), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberSegment(value=[0, 0.0]), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectSegment(value=[{}]), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileSegment(value=[file1, file2]), + SegmentType.ARRAY_FILE, + ), + ] + + for test_case in cases: + segment = test_case.segment + assert get_segment_discriminator(segment) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for type {type(segment)}" + ) + model_dict = segment.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(segment)}" + ) + + def test_variable_discriminator_function_for_variable_types(self): + """Test the variable discriminator function""" + + @dataclasses.dataclass + class TestCase: + variable: Variable + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneVariable(name="none_var"), + SegmentType.NONE, + ), + TestCase( + StringVariable(value="test", name="string_var"), + SegmentType.STRING, + ), + TestCase( + FloatVariable(value=0.0, name="float_var"), + SegmentType.FLOAT, + ), + TestCase( + IntegerVariable(value=0, name="int_var"), + SegmentType.INTEGER, + ), + TestCase( + ObjectVariable(value={}, name="object_var"), + SegmentType.OBJECT, + ), + TestCase( + FileVariable(value=file1, name="file_var"), + SegmentType.FILE, + ), + TestCase( + SecretVariable(value="secret", name="secret_var"), + SegmentType.SECRET, + ), + TestCase( + ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"), 
+ SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringVariable(value=[""], name="array_string_var"), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberVariable(value=[0, 0.0], name="array_number_var"), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectVariable(value=[{}], name="array_object_var"), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileVariable(value=[file1, file2], name="array_file_var"), + SegmentType.ARRAY_FILE, + ), + ] + + for test_case in cases: + variable = test_case.variable + assert get_segment_discriminator(variable) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for type {type(variable)}" + ) + model_dict = variable.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(variable)}" + ) + + def test_invalid_value_for_discriminator(self): + # Test invalid cases + assert get_segment_discriminator({"value_type": "invalid"}) is None + assert get_segment_discriminator({}) is None + assert get_segment_discriminator("not_a_dict") is None + assert get_segment_discriminator(42) is None + assert get_segment_discriminator(object) is None diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py new file mode 100644 index 000000000..64d0d8c7e --- /dev/null +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -0,0 +1,60 @@ +from core.variables.types import SegmentType + + +class TestSegmentTypeIsArrayType: + """ + Test class for SegmentType.is_array_type method. + + Provides comprehensive coverage of all SegmentType values to ensure + correct identification of array and non-array types. + """ + + def test_is_array_type(self): + """ + Test that all SegmentType enum values are covered in our test cases.
+ + Ensures comprehensive coverage by verifying that every SegmentType + value is tested for the is_array_type method. + """ + # Arrange + all_segment_types = set(SegmentType) + expected_array_types = [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] + expected_non_array_types = [ + SegmentType.INTEGER, + SegmentType.FLOAT, + SegmentType.NUMBER, + SegmentType.STRING, + SegmentType.OBJECT, + SegmentType.SECRET, + SegmentType.FILE, + SegmentType.NONE, + SegmentType.GROUP, + ] + + for seg_type in expected_array_types: + assert seg_type.is_array_type() + + for seg_type in expected_non_array_types: + assert not seg_type.is_array_type() + + # Act & Assert + covered_types = set(expected_array_types) | set(expected_non_array_types) + assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests" + + def test_all_enum_values_are_supported(self): + """ + Test that all enum values are supported and return boolean values. + + Validates that every SegmentType enum value can be processed by + is_array_type method and returns a boolean value. 
+ """ + enum_values: list[SegmentType] = list(SegmentType) + for seg_type in enum_values: + is_array = seg_type.is_array_type() + assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}" diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 426557c71..925142892 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,6 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) +from core.variables.variables import Variable def test_frozen_variables(): @@ -75,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name="text", value="text") + var: Variable = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py new file mode 100644 index 000000000..cf7cee871 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -0,0 +1,146 @@ +import time +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState +from core.workflow.system_variable import SystemVariable + + +def create_test_graph_runtime_state() -> GraphRuntimeState: + """Factory function to create a GraphRuntimeState with non-empty values for testing.""" + # Create a variable pool with system variables + system_vars = SystemVariable( + user_id="test_user_123", + 
app_id="test_app_456", + workflow_id="test_workflow_789", + workflow_execution_id="test_execution_001", + query="test query", + conversation_id="test_conv_123", + dialogue_count=5, + ) + variable_pool = VariablePool(system_variables=system_vars) + + # Add some variables to the variable pool + variable_pool.add(["test_node", "test_var"], "test_value") + variable_pool.add(["another_node", "another_var"], 42) + + # Create LLM usage with realistic values + llm_usage = LLMUsage( + prompt_tokens=150, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.15"), + completion_tokens=75, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.15"), + total_tokens=225, + total_price=Decimal("0.30"), + currency="USD", + latency=1.25, + ) + + # Create runtime route state with some node states + node_run_state = RuntimeRouteState() + node_state = node_run_state.create_node_state("test_node_1") + node_run_state.add_route(node_state.id, "target_node_id") + + return GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + total_tokens=100, + llm_usage=llm_usage, + outputs={ + "string_output": "test result", + "int_output": 42, + "float_output": 3.14, + "list_output": ["item1", "item2", "item3"], + "dict_output": {"key1": "value1", "key2": 123}, + "nested_dict": {"level1": {"level2": ["nested", "list", 456]}}, + }, + node_run_steps=5, + node_run_state=node_run_state, + ) + + +def test_basic_round_trip_serialization(): + """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged.""" + # Create a state with non-empty values + original_state = create_test_graph_runtime_state() + + # Serialize to JSON and deserialize back + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Core test: ensure the round-trip preserves all values + assert deserialized_state == 
original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="python") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="json") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + +def test_outputs_field_round_trip(): + """Test the problematic outputs field maintains values through round-trip serialization.""" + original_state = create_test_graph_runtime_state() + + # Serialize and deserialize + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Verify the outputs field specifically maintains its values + assert deserialized_state.outputs == original_state.outputs + assert deserialized_state == original_state + + +def test_empty_outputs_round_trip(): + """Test round-trip serialization with empty outputs field.""" + variable_pool = VariablePool.empty() + original_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + outputs={}, # Empty outputs + ) + + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + assert deserialized_state == original_state + + +def test_llm_usage_round_trip(): + # Create LLM usage with specific decimal values + llm_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.0015"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.15"), + completion_tokens=50, + completion_unit_price=Decimal("0.003"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.15"), + total_tokens=150, + total_price=Decimal("0.30"), + currency="USD", + latency=2.5, + ) + + json_data = llm_usage.model_dump_json() + deserialized = LLMUsage.model_validate_json(json_data) + assert deserialized == 
llm_usage + + dict_data = llm_usage.model_dump(mode="python") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage + + dict_data = llm_usage.model_dump(mode="json") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py new file mode 100644 index 000000000..f3de42479 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py @@ -0,0 +1,401 @@ +import json +import uuid +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState + +_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45) + + +class TestRouteNodeStateSerialization: + """Test cases for RouteNodeState Pydantic serialization/deserialization.""" + + def _test_route_node_state(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"input_key": "input_value"}, + outputs={"output_key": "output_value"}, + ) + + node_state = RouteNodeState( + node_id="comprehensive_test_node", + start_at=_TEST_DATETIME, + finished_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + node_run_result=node_run_result, + index=5, + paused_at=_TEST_DATETIME, + paused_by="user_123", + failed_reason="test_reason", + ) + return node_state + + def test_route_node_state_comprehensive_field_validation(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + node_state = self._test_route_node_state() + 
serialized = node_state.model_dump() + + # Comprehensive validation of all RouteNodeState fields + assert serialized["node_id"] == "comprehensive_test_node" + assert serialized["status"] == RouteNodeState.Status.SUCCESS + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["finished_at"] == _TEST_DATETIME + assert serialized["paused_at"] == _TEST_DATETIME + assert serialized["paused_by"] == "user_123" + assert serialized["failed_reason"] == "test_reason" + assert serialized["index"] == 5 + assert "id" in serialized + assert isinstance(serialized["id"], str) + uuid.UUID(serialized["id"]) # Validate UUID format + + # Validate nested NodeRunResult structure + assert serialized["node_run_result"] is not None + assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED + assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"} + assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"} + + def test_route_node_state_minimal_required_fields(self): + """Test RouteNodeState with only required fields, focusing on defaults.""" + node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME) + + serialized = node_state.model_dump() + + # Focus on required fields and default values (not re-testing all fields) + assert serialized["node_id"] == "minimal_node" + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status + assert serialized["index"] == 1 # Default index + assert serialized["node_run_result"] is None # Default None + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + def test_route_node_state_deserialization_from_dict(self): + """Test RouteNodeState deserialization from dictionary data.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + test_id = str(uuid.uuid4()) + + dict_data = { + "id": test_id, + 
"node_id": "deserialized_node", + "start_at": test_datetime, + "status": "success", + "finished_at": test_datetime, + "index": 3, + } + + node_state = RouteNodeState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert node_state.id == test_id + assert node_state.node_id == "deserialized_node" + assert node_state.start_at == test_datetime + assert node_state.status == RouteNodeState.Status.SUCCESS + assert node_state.finished_at == test_datetime + assert node_state.index == 3 + + def test_route_node_state_round_trip_consistency(self): + node_states = ( + self._test_route_node_state(), + RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME), + ) + for node_state in node_states: + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="python") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="json") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + +class TestRouteNodeStateEnumSerialization: + """Dedicated tests for RouteNodeState Status enum serialization behavior.""" + + def test_status_enum_model_dump_behavior(self): + """Test Status enum serialization in model_dump() returns enum objects.""" + + for status_enum in RouteNodeState.Status: + node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum) + serialized = node_state.model_dump(mode="python") + assert serialized["status"] == status_enum + serialized = node_state.model_dump(mode="json") + assert serialized["status"] == status_enum.value + + def test_status_enum_json_serialization_behavior(self): + """Test Status enum serialization in JSON returns string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + enum_to_string_mapping = { + RouteNodeState.Status.RUNNING: "running", + 
RouteNodeState.Status.SUCCESS: "success", + RouteNodeState.Status.FAILED: "failed", + RouteNodeState.Status.PAUSED: "paused", + RouteNodeState.Status.EXCEPTION: "exception", + } + + for status_enum, expected_string in enum_to_string_mapping.items(): + node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum) + + json_data = json.loads(node_state.model_dump_json()) + assert json_data["status"] == expected_string + + def test_status_enum_deserialization_from_string(self): + """Test Status enum deserialization from string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + string_to_enum_mapping = { + "running": RouteNodeState.Status.RUNNING, + "success": RouteNodeState.Status.SUCCESS, + "failed": RouteNodeState.Status.FAILED, + "paused": RouteNodeState.Status.PAUSED, + "exception": RouteNodeState.Status.EXCEPTION, + } + + for status_string, expected_enum in string_to_enum_mapping.items(): + dict_data = { + "node_id": "enum_deserialize_test", + "start_at": test_datetime, + "status": status_string, + } + + node_state = RouteNodeState.model_validate(dict_data) + assert node_state.status == expected_enum + + +class TestRuntimeRouteStateSerialization: + """Test cases for RuntimeRouteState Pydantic serialization/deserialization.""" + + _NODE1_ID = "node_1" + _ROUTE_STATE1_ID = str(uuid.uuid4()) + _NODE2_ID = "node_2" + _ROUTE_STATE2_ID = str(uuid.uuid4()) + _NODE3_ID = "node_3" + _ROUTE_STATE3_ID = str(uuid.uuid4()) + + def _get_runtime_route_state(self): + # Create node states with different configurations + node_state_1 = RouteNodeState( + id=self._ROUTE_STATE1_ID, + node_id=self._NODE1_ID, + start_at=_TEST_DATETIME, + index=1, + ) + node_state_2 = RouteNodeState( + id=self._ROUTE_STATE2_ID, + node_id=self._NODE2_ID, + start_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + finished_at=_TEST_DATETIME, + index=2, + ) + node_state_3 = RouteNodeState( + id=self._ROUTE_STATE3_ID, + node_id=self._NODE3_ID, + 
start_at=_TEST_DATETIME, + status=RouteNodeState.Status.FAILED, + failed_reason="Test failure", + index=3, + ) + + runtime_state = RuntimeRouteState( + routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]}, + node_state_mapping={ + node_state_1.id: node_state_1, + node_state_2.id: node_state_2, + node_state_3.id: node_state_3, + }, + ) + + return runtime_state + + def test_runtime_route_state_comprehensive_structure_validation(self): + """Test comprehensive RuntimeRouteState serialization with full structure validation.""" + + runtime_state = self._get_runtime_route_state() + serialized = runtime_state.model_dump() + + # Comprehensive validation of RuntimeRouteState structure + assert "routes" in serialized + assert "node_state_mapping" in serialized + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + # Validate routes dictionary structure and content + assert len(serialized["routes"]) == 2 + assert self._ROUTE_STATE1_ID in serialized["routes"] + assert self._ROUTE_STATE2_ID in serialized["routes"] + assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID] + assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID] + + # Validate node_state_mapping dictionary structure and content + assert len(serialized["node_state_mapping"]) == 3 + for state_id in [ + self._ROUTE_STATE1_ID, + self._ROUTE_STATE2_ID, + self._ROUTE_STATE3_ID, + ]: + assert state_id in serialized["node_state_mapping"] + node_data = serialized["node_state_mapping"][state_id] + node_state = runtime_state.node_state_mapping[state_id] + assert node_data["node_id"] == node_state.node_id + assert node_data["status"] == node_state.status + assert node_data["index"] == node_state.index + + def test_runtime_route_state_empty_collections(self): + """Test RuntimeRouteState with empty collections, focusing on default behavior.""" + runtime_state = 
RuntimeRouteState() + serialized = runtime_state.model_dump() + + # Focus on default empty collection behavior + assert serialized["routes"] == {} + assert serialized["node_state_mapping"] == {} + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + def test_runtime_route_state_json_serialization_structure(self): + """Test RuntimeRouteState JSON serialization structure.""" + node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME) + + runtime_state = RuntimeRouteState( + routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state} + ) + + json_str = runtime_state.model_dump_json() + json_data = json.loads(json_str) + + # Focus on JSON structure validation + assert isinstance(json_str, str) + assert isinstance(json_data, dict) + assert "routes" in json_data + assert "node_state_mapping" in json_data + assert json_data["routes"]["source"] == ["target1", "target2"] + assert node_state.id in json_data["node_state_mapping"] + + def test_runtime_route_state_deserialization_from_dict(self): + """Test RuntimeRouteState deserialization from dictionary data.""" + node_id = str(uuid.uuid4()) + + dict_data = { + "routes": {"source_node": ["target_node_1", "target_node_2"]}, + "node_state_mapping": { + node_id: { + "id": node_id, + "node_id": "test_node", + "start_at": _TEST_DATETIME, + "status": "running", + "index": 1, + } + }, + } + + runtime_state = RuntimeRouteState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]} + assert len(runtime_state.node_state_mapping) == 1 + assert node_id in runtime_state.node_state_mapping + + deserialized_node = runtime_state.node_state_mapping[node_id] + assert deserialized_node.node_id == "test_node" + assert deserialized_node.status == RouteNodeState.Status.RUNNING + assert deserialized_node.index == 1 + + def 
test_runtime_route_state_round_trip_consistency(self): + """Test RuntimeRouteState round-trip serialization consistency.""" + original = self._get_runtime_route_state() + + # Dictionary round trip + dict_data = original.model_dump(mode="python") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + dict_data = original.model_dump(mode="json") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + # JSON round trip + json_str = original.model_dump_json() + json_reconstructed = RuntimeRouteState.model_validate_json(json_str) + assert json_reconstructed == original + + +class TestSerializationEdgeCases: + """Test edge cases and error conditions for serialization/deserialization.""" + + def test_invalid_status_deserialization(self): + """Test deserialization with invalid status values.""" + test_datetime = _TEST_DATETIME + invalid_data = { + "node_id": "invalid_test", + "start_at": test_datetime, + "status": "invalid_status", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "status" in str(exc_info.value) + + def test_missing_required_fields_deserialization(self): + """Test deserialization with missing required fields.""" + incomplete_data = {"id": str(uuid.uuid4())} + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(incomplete_data) + error_str = str(exc_info.value) + assert "node_id" in error_str or "start_at" in error_str + + def test_invalid_datetime_deserialization(self): + """Test deserialization with invalid datetime values.""" + invalid_data = { + "node_id": "datetime_test", + "start_at": "invalid_datetime", + "status": "running", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "start_at" in str(exc_info.value) + + def test_invalid_routes_structure_deserialization(self): + """Test RuntimeRouteState 
deserialization with invalid routes structure.""" + invalid_data = { + "routes": "invalid_routes_structure", # Should be dict + "node_state_mapping": {}, + } + + with pytest.raises(ValidationError) as exc_info: + RuntimeRouteState.model_validate(invalid_data) + assert "routes" in str(exc_info.value) + + def test_timezone_handling_in_datetime_fields(self): + """Test timezone handling in datetime field serialization.""" + utc_datetime = datetime.now(UTC) + naive_datetime = utc_datetime.replace(tzinfo=None) + + node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime) + dict_ = node_state.model_dump() + + assert dict_["start_at"] == naive_datetime + + # Test round trip + reconstructed = RouteNodeState.model_validate(dict_) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None + + json = node_state.model_dump_json() + + reconstructed = RouteNodeState.model_validate_json(json) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index c288a5fa1..ed4e42425 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, @@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent 
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]), + user_inputs={"query": "hi"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="what's the weather in SF", + conversation_id="abababa", + ), user_inputs={}, ) @@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "hi", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="hi", + conversation_id="abababa", + ), user_inputs={"uid": "takato"}, ) @@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - 
SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + system_variables=SystemVariable( + user_id="aaa", + files=[], + ), + user_inputs={"query": "hi"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index b7f78d91f..85ff4f9c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -51,7 +51,7 @@ def test_execute_answer(): # construct variable pool pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, 
environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index c3a381865..137e8b889 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -3,7 +3,6 @@ from collections.abc import Generator from datetime import UTC, datetime from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, NodeRunStartedEvent, @@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: @@ -180,12 +180,12 @@ def test_process(): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="what's the weather in SF", + conversation_id="abababa", + ), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index d066fc1e3..bb6d72f51 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import ( ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout from core.workflow.nodes.http_request.executor import Executor +from core.workflow.system_variable import SystemVariable def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool() + variable_pool = VariablePool(system_variables=SystemVariable.empty()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), 
user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -280,7 +281,11 @@ def test_init_headers(): authorization=HttpRequestNodeAuthorization(type="no-auth"), ) timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) - return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + return Executor( + node_data=node_data, + timeout=timeout, + variable_pool=VariablePool(system_variables=SystemVariable.empty()), + ) executor = create_executor("aa\n cc:") executor._init_headers() @@ -310,7 +315,11 @@ def test_init_params(): authorization=HttpRequestNodeAuthorization(type="no-auth"), ) timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) - return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + return Executor( + node_data=node_data, + timeout=timeout, + variable_pool=VariablePool(system_variables=SystemVariable.empty()), + ) # Test basic key-value pairs executor = create_executor("key1:value1\nkey2:value2") diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 7fd32a482..33f9251a7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeBody, HttpRequestNodeData, ) +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch): ), ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add( @@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch): ), ) variable_pool = VariablePool( - system_variables={}, + 
system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add( @@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 362072a3d..17c23b773 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -151,12 +151,12 @@ def test_run(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + 
conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -368,12 +368,12 @@ def test_run_parallel(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -808,12 +808,12 @@ def test_iteration_run_error_handle(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 336c2befc..fefad0ec9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.provider import ProviderType from 
models.workflow import WorkflowType @@ -104,7 +105,7 @@ def graph() -> Graph: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) return GraphRuntimeState( @@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment(): related_id="1", storage_key="", ) - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], file) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment(): storage_key="", ), ] - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment(): def test_fetch_files_with_none_segment(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], NoneSegment()) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment(): def test_fetch_files_with_array_any_segment(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment(): def test_fetch_files_with_non_existent_variable(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index abc822e98..44c31b212 
100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -53,7 +53,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index a6c553faf..3f8342883 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( 
GraphRunPartialSucceededEvent, NodeRunExceptionEvent, @@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper: """Helper method to create a graph engine instance for testing""" graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "clear", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="clear", + conversation_id="abababa", + ), user_inputs=user_inputs or {"uid": "takato"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c4e411f9d..167a92484 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.entities 
import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition from extensions.ext_database import db from models.enums import UserFrom @@ -37,9 +37,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} - ) + pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -157,7 +155,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index e121f6338..2776e5777 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.tool import ToolNode from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.system_variable import SystemVariable from models import UserFrom, WorkflowType @@ -34,7 +35,7 @@ def _create_tool_node(): version="1", ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) node = ToolNode( diff --git 
a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index deb3e29b8..62e3e3710 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -68,7 +68,7 @@ def test_overwrite_string_variable(): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -165,7 +165,7 @@ def test_append_variable_to_array(): conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -256,7 +256,7 @@ def 
test_clear_array(): conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 7c5597dd8..a3a90b059 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -5,12 +5,12 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -109,7 +109,7 @@ def test_remove_first_from_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -196,7 +196,7 @@ def test_remove_last_from_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: 
"conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -275,7 +275,7 @@ def test_remove_first_from_empty_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -354,7 +354,7 @@ def test_remove_last_from_empty_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py new file mode 100644 index 000000000..11d788ed7 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -0,0 +1,251 @@ +import json +from typing import Any + +import pytest +from pydantic import ValidationError + +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File +from core.workflow.system_variable import SystemVariable + +# Test data constants for SystemVariable serialization tests +VALID_BASE_DATA: dict[str, Any] = { + "user_id": "a20f06b1-8703-45ab-937c-860a60072113", + "app_id": "661bed75-458d-49c9-b487-fda0762677b9", + "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43", +} + +COMPLETE_VALID_DATA: dict[str, Any] = { + **VALID_BASE_DATA, + "query": "test query", + "files": [], + "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9", + "dialogue_count": 5, + "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5", +} + + +def create_test_file() -> File: + """Create a test File object for serialization tests.""" + return File( + 
tenant_id="test-tenant-id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test-file-id", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="test-storage-key", + ) + + +class TestSystemVariableSerialization: + """Focused tests for SystemVariable serialization/deserialization logic.""" + + def test_basic_deserialization(self): + """Test successful deserialization from JSON structure with all fields correctly mapped.""" + # Test with complete data + system_var = SystemVariable(**COMPLETE_VALID_DATA) + + # Verify all fields are correctly mapped + assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] + assert system_var.app_id == COMPLETE_VALID_DATA["app_id"] + assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"] + assert system_var.query == COMPLETE_VALID_DATA["query"] + assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"] + assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"] + assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"] + assert system_var.files == [] + + # Test with minimal data (only required fields) + minimal_var = SystemVariable(**VALID_BASE_DATA) + assert minimal_var.user_id == VALID_BASE_DATA["user_id"] + assert minimal_var.app_id == VALID_BASE_DATA["app_id"] + assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] + assert minimal_var.query is None + assert minimal_var.conversation_id is None + assert minimal_var.dialogue_count is None + assert minimal_var.workflow_execution_id is None + assert minimal_var.files == [] + + def test_alias_handling(self): + """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic.""" + workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5" + + # Test workflow_run_id only (preferred alias) + data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var1 = 
SystemVariable(**data_run_id) + assert system_var1.workflow_execution_id == workflow_id + + # Test workflow_execution_id only (direct field name) + data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var2 = SystemVariable(**data_execution_id) + assert system_var2.workflow_execution_id == workflow_id + + # Test both present - workflow_run_id should take precedence + data_both = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-ignored", + "workflow_run_id": workflow_id, + } + system_var3 = SystemVariable(**data_both) + assert system_var3.workflow_execution_id == workflow_id + + # Test neither present - should be None + system_var4 = SystemVariable(**VALID_BASE_DATA) + assert system_var4.workflow_execution_id is None + + def test_serialization_round_trip(self): + """Test that serialize → deserialize produces the same result with alias handling.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to dict + serialized = original.model_dump(mode="json") + + # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id) + assert "workflow_run_id" in serialized + assert "workflow_execution_id" not in serialized + assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize back + deserialized = SystemVariable(**serialized) + + # Verify all fields match after round-trip + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + assert deserialized.query == original.query + assert deserialized.conversation_id == original.conversation_id + assert deserialized.dialogue_count == original.dialogue_count + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert list(deserialized.files) == list(original.files) + + def test_json_round_trip(self): + """Test JSON serialization/deserialization 
consistency with proper structure.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to JSON string + json_str = original.model_dump_json() + + # Parse JSON and verify structure + json_data = json.loads(json_str) + assert "workflow_run_id" in json_data + assert "workflow_execution_id" not in json_data + assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize from JSON data + deserialized = SystemVariable(**json_data) + + # Verify key fields match after JSON round-trip + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + + def test_files_field_deserialization(self): + """Test deserialization with File objects in the files field - SystemVariable specific logic.""" + # Test with empty files list + data_empty = {**VALID_BASE_DATA, "files": []} + system_var_empty = SystemVariable(**data_empty) + assert system_var_empty.files == [] + + # Test with single File object + test_file = create_test_file() + data_single = {**VALID_BASE_DATA, "files": [test_file]} + system_var_single = SystemVariable(**data_single) + assert len(system_var_single.files) == 1 + assert system_var_single.files[0].filename == "test.txt" + assert system_var_single.files[0].tenant_id == "test-tenant-id" + + # Test with multiple File objects + file1 = File( + tenant_id="tenant1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="file1", + filename="doc1.txt", + storage_key="key1", + ) + file2 = File( + tenant_id="tenant2", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + filename="image.jpg", + storage_key="key2", + ) + + data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} + system_var_multiple = 
SystemVariable(**data_multiple) + assert len(system_var_multiple.files) == 2 + assert system_var_multiple.files[0].filename == "doc1.txt" + assert system_var_multiple.files[1].filename == "image.jpg" + + # Verify files field serialization/deserialization + serialized = system_var_multiple.model_dump(mode="json") + deserialized = SystemVariable(**serialized) + assert len(deserialized.files) == 2 + assert deserialized.files[0].filename == "doc1.txt" + assert deserialized.files[1].filename == "image.jpg" + + def test_alias_serialization_consistency(self): + """Test that alias handling works consistently in both serialization directions.""" + workflow_id = "test-workflow-id" + + # Create with workflow_run_id (alias) + data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var = SystemVariable(**data_with_alias) + + # Serialize and verify alias is used + serialized = system_var.model_dump() + assert serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in serialized + + # Deserialize and verify field mapping + deserialized = SystemVariable(**serialized) + assert deserialized.workflow_execution_id == workflow_id + + # Test JSON serialization path + json_serialized = json.loads(system_var.model_dump_json()) + assert json_serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in json_serialized + + json_deserialized = SystemVariable(**json_serialized) + assert json_deserialized.workflow_execution_id == workflow_id + + def test_model_validator_serialization_logic(self): + """Test the custom model validator behavior for serialization scenarios.""" + workflow_id = "test-workflow-execution-id" + + # Test direct instantiation with workflow_execution_id (should work) + data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var1 = SystemVariable(**data1) + assert system_var1.workflow_execution_id == workflow_id + + # Test serialization of the above (should use alias) + serialized1 = 
system_var1.model_dump() + assert "workflow_run_id" in serialized1 + assert serialized1["workflow_run_id"] == workflow_id + + # Test both present - workflow_run_id takes precedence (validator logic) + data2 = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-removed", + "workflow_run_id": workflow_id, + } + system_var2 = SystemVariable(**data2) + assert system_var2.workflow_execution_id == workflow_id + + # Verify serialization consistency + serialized2 = system_var2.model_dump() + assert serialized2["workflow_run_id"] == workflow_id + + +def test_constructor_with_extra_key(): + # Test that SystemVariable should forbid extra keys + with pytest.raises(ValidationError): + # This should fail because there is an unexpected key. + SystemVariable(invalid_key=1) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index bb8d34fad..c65b60cb4 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,17 +1,43 @@ +import uuid +from collections import defaultdict + import pytest -from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, +) +from core.variables.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + StringVariable, + VariableUnion, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import 
VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture def pool(): - return VariablePool(system_variables={}, user_inputs={}) + return VariablePool( + system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"), + user_inputs={}, + ) @pytest.fixture @@ -52,18 +78,28 @@ def test_use_long_selector(pool): class TestVariablePool: def test_constructor(self): - pool = VariablePool() + # Test with minimal required SystemVariable + minimal_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) + pool = VariablePool(system_variables=minimal_system_vars) + + # Test with all parameters pool = VariablePool( variable_dictionary={}, user_inputs={}, - system_variables={}, + system_variables=minimal_system_vars, environment_variables=[], conversation_variables=[], ) + # Test with more complex SystemVariable + complex_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) pool = VariablePool( user_inputs={"key": "value"}, - system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + system_variables=complex_system_vars, environment_variables=[ segment_to_variable( segment=build_segment(1), @@ -80,6 +116,323 @@ class TestVariablePool: ], ) - def test_constructor_with_invalid_system_variable_key(self): - with pytest.raises(ValidationError): - VariablePool(system_variables={"invalid_key": "value"}) # type: ignore + def test_get_system_variables(self): + sys_var = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_var) + + kv = [ + ("user_id", 
sys_var.user_id), + ("app_id", sys_var.app_id), + ("workflow_id", sys_var.workflow_id), + ("workflow_run_id", sys_var.workflow_execution_id), + ("query", sys_var.query), + ("conversation_id", sys_var.conversation_id), + ("dialogue_count", sys_var.dialogue_count), + ] + for key, expected_value in kv: + segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key]) + assert segment is not None + assert segment.value == expected_value + + +class TestVariablePoolSerialization: + """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods. + + These tests focus exclusively on serialization/deserialization logic to ensure that + VariablePool data can be properly serialized to dictionaries/JSON and reconstructed + while preserving all data integrity. + """ + + _NODE1_ID = "node_1" + _NODE2_ID = "node_2" + _NODE3_ID = "node_3" + + def _create_pool_without_file(self): + # Create comprehensive system variables + system_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + + # Create environment variables with all types including ArrayFileVariable + env_vars: list[VariableUnion] = [ + StringVariable( + id="env_string_id", + name="env_string", + value="env_string_value", + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"], + ), + IntegerVariable( + id="env_integer_id", + name="env_integer", + value=1, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"], + ), + FloatVariable( + id="env_float_id", + name="env_float", + value=1.0, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"], + ), + ] + + # Create conversation variables with complex data + conv_vars: list[VariableUnion] = [ + StringVariable( + id="conv_string_id", + name="conv_string", + value="conv_string_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"], + ), + IntegerVariable( + 
id="conv_integer_id", + name="conv_integer", + value=1, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"], + ), + FloatVariable( + id="conv_float_id", + name="conv_float", + value=1.0, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"], + ), + ObjectVariable( + id="conv_object_id", + name="conv_object", + value={"key": "value", "nested": {"data": 123}}, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"], + ), + ArrayStringVariable( + id="conv_array_string_id", + name="conv_array_string", + value=["conv_array_string_value"], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"], + ), + ArrayNumberVariable( + id="conv_array_number_id", + name="conv_array_number", + value=[1, 1.0], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"], + ), + ArrayObjectVariable( + id="conv_array_object_id", + name="conv_array_object", + value=[{"a": 1}, {"b": "2"}], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"], + ), + ] + + # Create comprehensive user inputs + user_inputs = { + "string_input": "test_value", + "number_input": 42, + "object_input": {"nested": {"key": "value"}}, + "array_input": ["item1", "item2", "item3"], + } + + # Create VariablePool + pool = VariablePool( + system_variables=system_vars, + user_inputs=user_inputs, + environment_variables=env_vars, + conversation_variables=conv_vars, + ) + return pool + + def _add_node_data_to_pool(self, pool: VariablePool, with_file=False): + test_file = File( + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_related_id", + remote_url="test_url", + filename="test_file.txt", + storage_key="test_storage_key", + ) + + # Add various segment types to variable dictionary + pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string")) + pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123)) + pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67)) + 
pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"})) + if with_file: + pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file)) + pool.add((self._NODE1_ID, "none_var"), NoneSegment()) + + # Add array segments including ArrayFileVariable + pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"])) + pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3])) + pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}])) + if with_file: + pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) + pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) + + # Add nested variables + pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) + + def test_system_variables(self): + sys_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_vars) + json = pool.model_dump_json() + pool2 = VariablePool.model_validate_json(json) + assert pool2.system_variables == sys_vars + + for mode in ["json", "python"]: + dict_ = pool.model_dump(mode=mode) + pool2 = VariablePool.model_validate(dict_) + assert pool2.system_variables == sys_vars + + def test_pool_without_file_vars(self): + pool = self._create_pool_without_file() + json = pool.model_dump_json() + pool2 = pool.model_validate_json(json) + assert pool2.system_variables == pool.system_variables + assert pool2.conversation_variables == pool.conversation_variables + assert pool2.environment_variables == pool.environment_variables + assert pool2.user_inputs == pool.user_inputs + assert pool2.variable_dictionary == pool.variable_dictionary + assert pool2 == pool + + def test_basic_dictionary_round_trip(self): + 
"""Test basic round-trip serialization: model_dump() → model_validate()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to dictionary using Pydantic's model_dump() + serialized_data = original_pool.model_dump() + + # Verify serialized data structure + assert isinstance(serialized_data, dict) + assert "system_variables" in serialized_data + assert "user_inputs" in serialized_data + assert "environment_variables" in serialized_data + assert "conversation_variables" in serialized_data + assert "variable_dictionary" in serialized_data + + # Deserialize back using Pydantic's model_validate() + reconstructed_pool = VariablePool.model_validate(serialized_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_json_round_trip(self): + """Test JSON round-trip serialization: model_dump_json() → model_validate_json()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to JSON string using Pydantic's model_dump_json() + json_data = original_pool.model_dump_json() + + # Verify JSON is valid string + assert isinstance(json_data, str) + assert len(json_data) > 0 + + # Deserialize back using Pydantic's model_validate_json() + reconstructed_pool = VariablePool.model_validate_json(json_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_complex_data_serialization(self): + """Test serialization of complex data structures including ArrayFileVariable""" + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool, with_file=True) + + # Test dictionary round-trip + dict_data = original_pool.model_dump() + reconstructed_dict = VariablePool.model_validate(dict_data) + + # Test 
JSON round-trip + json_data = original_pool.model_dump_json() + reconstructed_json = VariablePool.model_validate_json(json_data) + + # Verify both reconstructed pools are equivalent + self._assert_pools_equal(reconstructed_dict, reconstructed_json) + # TODO: assert the data for file object... + + def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None: + """Assert that two VariablePools contain equivalent data""" + + # Compare system variables + assert pool1.system_variables == pool2.system_variables + + # Compare user inputs + assert dict(pool1.user_inputs) == dict(pool2.user_inputs) + + # Compare environment variables count + assert pool1.environment_variables == pool2.environment_variables + + # Compare conversation variables count + assert pool1.conversation_variables == pool2.conversation_variables + + # Test key variable retrievals to ensure functionality is preserved + test_selectors = [ + (SYSTEM_VARIABLE_NODE_ID, "user_id"), + (SYSTEM_VARIABLE_NODE_ID, "app_id"), + (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"), + (ENVIRONMENT_VARIABLE_NODE_ID, "env_number"), + (CONVERSATION_VARIABLE_NODE_ID, "conv_string"), + (self._NODE1_ID, "string_var"), + (self._NODE1_ID, "int_var"), + (self._NODE1_ID, "float_var"), + (self._NODE2_ID, "array_string"), + (self._NODE2_ID, "array_number"), + (self._NODE3_ID, "nested", "deep", "var"), + ] + + for selector in test_selectors: + val1 = pool1.get(selector) + val2 = pool2.get(selector) + + # Both should exist or both should be None + assert (val1 is None) == (val2 is None) + + if val1 is not None and val2 is not None: + # Values should be equal + assert val1.value == val2.value + # Value types should be the same (more important than exact class type) + assert val1.value_type == val2.value_type + + def test_variable_pool_deserialization_default_dict(self): + variable_pool = VariablePool( + user_inputs={"a": 1, "b": "2"}, + system_variables=SystemVariable(workflow_id=str(uuid.uuid4())), + 
environment_variables=[ + StringVariable(name="str_var", value="a"), + ], + conversation_variables=[IntegerVariable(name="int_var", value=1)], + ) + assert isinstance(variable_pool.variable_dictionary, defaultdict) + json = variable_pool.model_dump_json() + loaded = VariablePool.model_validate_json(json) + assert isinstance(loaded.variable_dictionary, defaultdict) + + loaded.add(["non_exist_node", "a"], 1) + + pool_dict = variable_pool.model_dump() + loaded = VariablePool.model_validate(pool_dict) + assert isinstance(loaded.variable_dictionary, defaultdict) + loaded.add(["non_exist_node", "a"], 1) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 646de8bf3..642bc810b 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from models.enums import CreatorUserRole from models.model import AppMode @@ -67,14 +67,14 @@ def real_app_generate_entity(): @pytest.fixture def real_workflow_system_variables(): - return { - SystemVariableKey.QUERY: "test query", - SystemVariableKey.CONVERSATION_ID: "test-conversation-id", - SystemVariableKey.USER_ID: "test-user-id", - SystemVariableKey.APP_ID: "test-app-id", - SystemVariableKey.WORKFLOW_ID: "test-workflow-id", - 
SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id", - } + return SystemVariable( + query="test query", + conversation_id="test-conversation-id", + user_id="test-user-id", + app_id="test-app-id", + workflow_id="test-workflow-id", + workflow_execution_id="test-workflow-run-id", + ) @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py index f1cb937bb..54bf6558b 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -10,7 +10,7 @@ class TestAppendVariablesRecursively: def test_append_simple_dict_value(self): """Test appending a simple dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["output"] variable_value = {"name": "John", "age": 30} @@ -33,7 +33,7 @@ class TestAppendVariablesRecursively: def test_append_object_segment_value(self): """Test appending an ObjectSegment value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["result"] @@ -60,7 +60,7 @@ class TestAppendVariablesRecursively: def test_append_nested_dict_value(self): """Test appending a nested dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["data"] @@ -97,7 +97,7 @@ class TestAppendVariablesRecursively: def test_append_non_dict_value(self): """Test appending a non-dictionary value (should not recurse)""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["simple"] variable_value = "simple_string" @@ -114,7 +114,7 @@ class TestAppendVariablesRecursively: def test_append_segment_non_object_value(self): """Test appending a Segment that is not ObjectSegment (should not recurse)""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list 
= ["text"] variable_value = StringSegment(value="Hello World") @@ -132,7 +132,7 @@ class TestAppendVariablesRecursively: def test_append_empty_dict_value(self): """Test appending an empty dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["empty"] variable_value: dict[str, Any] = {} diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index edd4c5e93..4f2542a32 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -505,8 +505,8 @@ def test_build_segment_type_for_scalar(): size=1000, ) cases = [ - TestCase(0, SegmentType.NUMBER), - TestCase(0.0, SegmentType.NUMBER), + TestCase(0, SegmentType.INTEGER), + TestCase(0.0, SegmentType.FLOAT), TestCase("", SegmentType.STRING), TestCase(file, SegmentType.FILE), ] @@ -531,14 +531,14 @@ class TestBuildSegmentWithType: result = build_segment_with_type(SegmentType.NUMBER, 42) assert isinstance(result, IntegerSegment) assert result.value == 42 - assert result.value_type == SegmentType.NUMBER + assert result.value_type == SegmentType.INTEGER def test_number_type_float(self): """Test building a number segment with float value.""" result = build_segment_with_type(SegmentType.NUMBER, 3.14) assert isinstance(result, FloatSegment) assert result.value == 3.14 - assert result.value_type == SegmentType.NUMBER + assert result.value_type == SegmentType.FLOAT def test_object_type(self): """Test building an object segment with correct type.""" @@ -652,14 +652,14 @@ class TestBuildSegmentWithType: with pytest.raises(TypeMismatchError) as exc_info: build_segment_with_type(SegmentType.STRING, None) - assert "Expected string, but got None" in str(exc_info.value) + assert "expected string, but got None" in str(exc_info.value) def test_type_mismatch_empty_list_to_non_array(self): """Test type mismatch when expecting non-array type 
but getting empty list.""" with pytest.raises(TypeMismatchError) as exc_info: build_segment_with_type(SegmentType.STRING, []) - assert "Expected string, but got empty list" in str(exc_info.value) + assert "expected string, but got empty list" in str(exc_info.value) def test_type_mismatch_object_to_array(self): """Test type mismatch when expecting array but getting object.""" @@ -674,19 +674,19 @@ class TestBuildSegmentWithType: # Integer should work result_int = build_segment_with_type(SegmentType.NUMBER, 42) assert isinstance(result_int, IntegerSegment) - assert result_int.value_type == SegmentType.NUMBER + assert result_int.value_type == SegmentType.INTEGER # Float should work result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) assert isinstance(result_float, FloatSegment) - assert result_float.value_type == SegmentType.NUMBER + assert result_float.value_type == SegmentType.FLOAT @pytest.mark.parametrize( ("segment_type", "value", "expected_class"), [ (SegmentType.STRING, "test", StringSegment), - (SegmentType.NUMBER, 42, IntegerSegment), - (SegmentType.NUMBER, 3.14, FloatSegment), + (SegmentType.INTEGER, 42, IntegerSegment), + (SegmentType.FLOAT, 3.14, FloatSegment), (SegmentType.OBJECT, {}, ObjectSegment), (SegmentType.NONE, None, NoneSegment), (SegmentType.ARRAY_STRING, [], ArrayStringSegment), @@ -857,5 +857,5 @@ class TestBuildSegmentValueErrors: # Verify they are processed as integers, not as errors assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" - assert true_segment.value_type == SegmentType.NUMBER - assert false_segment.value_type == SegmentType.NUMBER + assert true_segment.value_type == SegmentType.INTEGER + assert false_segment.value_type == SegmentType.INTEGER diff --git a/web/app/components/base/chat/chat/question.tsx b/web/app/components/base/chat/chat/question.tsx index 
d22158794..cae8e2b8c 100644 --- a/web/app/components/base/chat/chat/question.tsx +++ b/web/app/components/base/chat/chat/question.tsx @@ -98,7 +98,7 @@ const Question: FC = ({ return (
-
+
= ({
{ diff --git a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx index 8616f3420..4ff0cd780 100644 --- a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx +++ b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx @@ -61,7 +61,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { >
@@ -73,7 +73,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { src={icon} alt='tool icon' className={classNames( - 'w-full h-full size-3.5 object-cover', + 'size-3.5 h-full w-full object-cover', notSuccess && 'opacity-50', )} onError={() => setIconFetchError(true)} @@ -82,7 +82,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { if (typeof icon === 'object') { return