diff --git a/api/child_class.py b/api/child_class.py new file mode 100644 index 000000000..b210607b9 --- /dev/null +++ b/api/child_class.py @@ -0,0 +1,11 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class ChildClass(ParentClass): + """Test child class for module import helper tests""" + + def __init__(self, name): + super().__init__(name) + + def get_name(self): + return f"Child: {self.name}" diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 2f2445a33..637573344 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -3,6 +3,17 @@ import re from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory +_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( + [ + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.EXTERNAL_DATA_TOOL, + VariableEntityType.CHECKBOX, + ] +) + class BasicVariablesConfigManager: @classmethod @@ -47,6 +58,7 @@ class BasicVariablesConfigManager: VariableEntityType.PARAGRAPH, VariableEntityType.NUMBER, VariableEntityType.SELECT, + VariableEntityType.CHECKBOX, }: variable = variables[variable_type] variable_entities.append( @@ -96,8 +108,17 @@ class BasicVariablesConfigManager: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: - raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") + # if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: + if key not in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.EXTERNAL_DATA_TOOL, + VariableEntityType.CHECKBOX, + }: + allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE) + raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}") form_item = item[key] if "label" not in form_item: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 0db1d5277..df2074df2 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -97,6 +97,7 @@ class VariableEntityType(StrEnum): EXTERNAL_DATA_TOOL = "external_data_tool" FILE = "file" FILE_LIST = "file-list" + CHECKBOX = "checkbox" class VariableEntity(BaseModel): diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index beece1d77..42634fc48 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -103,18 +103,23 @@ class BaseAppGenerator: f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" ) - if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): - # handle empty string case - if not value.strip(): - return None - # may raise ValueError if user_input_value is not a valid number - try: - if "." in value: - return float(value) - else: - return int(value) - except ValueError: - raise ValueError(f"{variable_entity.variable} in input form must be a valid number") + if variable_entity.type == VariableEntityType.NUMBER: + if isinstance(value, (int, float)): + return value + elif isinstance(value, str): + # handle empty string case + if not value.strip(): + return None + # may raise ValueError if user_input_value is not a valid number + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") + else: + raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}") match variable_entity.type: case VariableEntityType.SELECT: @@ -144,6 +149,11 @@ class BaseAppGenerator: raise ValueError( f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" ) + case VariableEntityType.CHECKBOX: + if not isinstance(value, bool): + raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value") + case _: + raise AssertionError("this statement should be unreachable.") return value diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index a99f5eece..9e7616874 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -151,6 +151,11 @@ class FileSegment(Segment): return "" +class BooleanSegment(Segment): + value_type: SegmentType = SegmentType.BOOLEAN + value: bool + + class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY value: Sequence[Any] @@ -198,6 +203,11 @@ class ArrayFileSegment(ArraySegment): return "" +class ArrayBooleanSegment(ArraySegment): + value_type: SegmentType = SegmentType.ARRAY_BOOLEAN + value: Sequence[bool] + + def get_segment_discriminator(v: Any) -> SegmentType | None: if isinstance(v, Segment): return v.value_type @@ -231,11 +241,13 @@ SegmentUnion: TypeAlias = Annotated[ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] ), Discriminator(get_segment_discriminator), ] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 662905604..55f8ae3c7 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -6,7 +6,12 @@ from core.file.models import File class ArrayValidation(StrEnum): - """Strategy for validating array elements""" + """Strategy for validating array elements. + + Note: + The `NONE` and `FIRST` strategies are primarily for compatibility purposes. + Avoid using them in new code whenever possible. + """ # Skip element validation (only check array container) NONE = "none" @@ -27,12 +32,14 @@ class SegmentType(StrEnum): SECRET = "secret" FILE = "file" + BOOLEAN = "boolean" ARRAY_ANY = "array[any]" ARRAY_STRING = "array[string]" ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" ARRAY_FILE = "array[file]" + ARRAY_BOOLEAN = "array[boolean]" NONE = "none" @@ -76,12 +83,18 @@ class SegmentType(StrEnum): return SegmentType.ARRAY_FILE case SegmentType.NONE: return SegmentType.ARRAY_ANY + case SegmentType.BOOLEAN: + return SegmentType.ARRAY_BOOLEAN case _: # This should be unreachable. raise ValueError(f"not supported value {value}") if value is None: return SegmentType.NONE - elif isinstance(value, int) and not isinstance(value, bool): + # Important: The check for `bool` must precede the check for `int`, + # as `bool` is a subclass of `int` in Python's type hierarchy. + elif isinstance(value, bool): + return SegmentType.BOOLEAN + elif isinstance(value, int): return SegmentType.INTEGER elif isinstance(value, float): return SegmentType.FLOAT @@ -111,7 +124,7 @@ class SegmentType(StrEnum): else: return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: + def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: """ Check if a value matches the segment type. Users of `SegmentType` should call this method, instead of using @@ -126,6 +139,10 @@ class SegmentType(StrEnum): """ if self.is_array_type(): return self._validate_array(value, array_validation) + # Important: The check for `bool` must precede the check for `int`, + # as `bool` is a subclass of `int` in Python's type hierarchy. + elif self == SegmentType.BOOLEAN: + return isinstance(value, bool) elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: return isinstance(value, (int, float)) elif self == SegmentType.STRING: @@ -141,6 +158,27 @@ class SegmentType(StrEnum): else: raise AssertionError("this statement should be unreachable.") + @staticmethod + def cast_value(value: Any, type_: "SegmentType") -> Any: + # Cast Python's `bool` type to `int` when the runtime type requires + # an integer or number. + # + # This ensures compatibility with existing workflows that may use `bool` as + # `int`, since in Python's type system, `bool` is a subtype of `int`. + # + # This function exists solely to maintain compatibility with existing workflows. + # It should not be used to compromise the integrity of the runtime type system. + # No additional casting rules should be introduced to this function. + + if type_ in ( + SegmentType.INTEGER, + SegmentType.NUMBER, + ) and isinstance(value, bool): + return int(value) + if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value): + return [int(i) for i in value] + return value + def exposed_type(self) -> "SegmentType": """Returns the type exposed to the frontend. @@ -150,6 +188,20 @@ class SegmentType(StrEnum): return SegmentType.NUMBER return self + def element_type(self) -> "SegmentType | None": + """Return the element type of the current segment type, or `None` if the element type is undefined. + + Raises: + ValueError: If the current segment type is not an array type. + + Note: + For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined + by the runtime system. In such cases, this method will return `None`. + """ + if not self.is_array_type(): + raise ValueError(f"element_type is only supported by array type, got {self}") + return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) + _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { # ARRAY_ANY does not have corresponding element type. @@ -157,6 +209,7 @@ _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, SegmentType.ARRAY_FILE: SegmentType.FILE, + SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN, } _ARRAY_TYPES = frozenset( diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index a31ebc848..16c8116ac 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -8,11 +8,13 @@ from core.helper import encrypter from .segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArraySegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, @@ -96,10 +98,18 @@ class FileVariable(FileSegment, Variable): pass +class BooleanVariable(BooleanSegment, Variable): + pass + + class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass +class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): + pass + + # The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. # Use `Variable` for type hinting when serialization is not required. # @@ -114,11 +124,13 @@ VariableUnion: TypeAlias = Annotated[ | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)] | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)] | Annotated[SecretVariable, Tag(SegmentType.SECRET)] ), Discriminator(get_segment_discriminator), diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index fdf393282..17bd841fc 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.variables.segments import ArrayFileSegment +from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -119,6 +120,14 @@ class CodeNode(BaseNode): return value.replace("\x00", "") + def _check_boolean(self, value: bool | None, variable: str) -> bool | None: + if value is None: + return None + if not isinstance(value, bool): + raise OutputValidationError(f"Output variable `{variable}` must be a boolean") + + return value + def _check_number(self, value: int | float | None, variable: str) -> int | float | None: """ Check number @@ -173,6 +182,8 @@ class CodeNode(BaseNode): prefix=f"{prefix}.{output_name}" if prefix else output_name, depth=depth + 1, ) + elif isinstance(output_value, bool): + self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name) elif isinstance(output_value, int | float): self._check_number( value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name @@ -232,7 +243,7 @@ class CodeNode(BaseNode): if output_name not in result: raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") - if output_config.type == "object": + if output_config.type == SegmentType.OBJECT: # check if output is object if not isinstance(result.get(output_name), dict): if result[output_name] is None: @@ -249,18 +260,28 @@ class CodeNode(BaseNode): prefix=f"{prefix}.{output_name}", depth=depth + 1, ) - elif output_config.type == "number": + elif output_config.type == SegmentType.NUMBER: # check if number available - transformed_result[output_name] = self._check_number( - value=result[output_name], variable=f"{prefix}{dot}{output_name}" - ) - elif output_config.type == "string": + checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}") + # If the output is a boolean and the output schema specifies a NUMBER type, + # convert the boolean value to an integer. + # + # This ensures compatibility with existing workflows that may use + # `True` and `False` as values for NUMBER type outputs. + transformed_result[output_name] = self._convert_boolean_to_int(checked) + + elif output_config.type == SegmentType.STRING: # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == "array[number]": + elif output_config.type == SegmentType.BOOLEAN: + transformed_result[output_name] = self._check_boolean( + value=result[output_name], + variable=f"{prefix}{dot}{output_name}", + ) + elif output_config.type == SegmentType.ARRAY_NUMBER: # check if array of number available if not isinstance(result[output_name], list): if result[output_name] is None: @@ -278,10 +299,17 @@ class CodeNode(BaseNode): ) transformed_result[output_name] = [ - self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") + # If the element is a boolean and the output schema specifies a `array[number]` type, + # convert the boolean value to an integer. + # + # This ensures compatibility with existing workflows that may use + # `True` and `False` as values for NUMBER type outputs. + self._convert_boolean_to_int( + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"), + ) for i, value in enumerate(result[output_name]) ] - elif output_config.type == "array[string]": + elif output_config.type == SegmentType.ARRAY_STRING: # check if array of string available if not isinstance(result[output_name], list): if result[output_name] is None: @@ -302,7 +330,7 @@ class CodeNode(BaseNode): self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == "array[object]": + elif output_config.type == SegmentType.ARRAY_OBJECT: # check if array of object available if not isinstance(result[output_name], list): if result[output_name] is None: @@ -340,6 +368,22 @@ class CodeNode(BaseNode): ) for i, value in enumerate(result[output_name]) ] + elif output_config.type == SegmentType.ARRAY_BOOLEAN: + # check if array of object available + if not isinstance(result[output_name], list): + if result[output_name] is None: + transformed_result[output_name] = None + else: + raise OutputValidationError( + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." + ) + else: + transformed_result[output_name] = [ + self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") + for i, value in enumerate(result[output_name]) + ] + else: raise OutputValidationError(f"Output type {output_config.type} is not supported.") @@ -374,3 +418,16 @@ class CodeNode(BaseNode): @property def retry(self) -> bool: return self._node_data.retry_config.retry_enabled + + @staticmethod + def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: + """This function convert boolean to integers when the output schema specifies a NUMBER type. + + This ensures compatibility with existing workflows that may use + `True` and `False` as values for NUMBER type outputs. + """ + if value is None: + return None + if isinstance(value, bool): + return int(value) + return value diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index a45403588..9d380c6fb 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,11 +1,31 @@ -from typing import Literal, Optional +from typing import Annotated, Literal, Optional -from pydantic import BaseModel +from pydantic import AfterValidator, BaseModel from core.helper.code_executor.code_executor import CodeLanguage +from core.variables.types import SegmentType from core.workflow.entities.variable_entities import VariableSelector from core.workflow.nodes.base import BaseNodeData +_ALLOWED_OUTPUT_FROM_CODE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.BOOLEAN, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_BOOLEAN, + ] +) + + +def _validate_type(segment_type: SegmentType) -> SegmentType: + if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: + raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") + return segment_type + class CodeNodeData(BaseNodeData): """ @@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData): """ class Output(BaseModel): - type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + type: Annotated[SegmentType, AfterValidator(_validate_type)] children: Optional[dict[str, "CodeNodeData.Output"]] = None class Dependency(BaseModel): diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/core/workflow/nodes/list_operator/entities.py index 75df784a9..e51a91f07 100644 --- a/api/core/workflow/nodes/list_operator/entities.py +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -1,36 +1,43 @@ from collections.abc import Sequence -from typing import Literal +from enum import StrEnum from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -_Condition = Literal[ + +class FilterOperator(StrEnum): # string conditions - "contains", - "start with", - "end with", - "is", - "in", - "empty", - "not contains", - "is not", - "not in", - "not empty", + CONTAINS = "contains" + START_WITH = "start with" + END_WITH = "end with" + IS = "is" + IN = "in" + EMPTY = "empty" + NOT_CONTAINS = "not contains" + IS_NOT = "is not" + NOT_IN = "not in" + NOT_EMPTY = "not empty" # number conditions - "=", - "≠", - "<", - ">", - "≥", - "≤", -] + EQUAL = "=" + NOT_EQUAL = "≠" + LESS_THAN = "<" + GREATER_THAN = ">" + GREATER_THAN_OR_EQUAL = "≥" + LESS_THAN_OR_EQUAL = "≤" + + +class Order(StrEnum): + ASC = "asc" + DESC = "desc" class FilterCondition(BaseModel): key: str = "" - comparison_operator: _Condition = "contains" - value: str | Sequence[str] = "" + comparison_operator: FilterOperator = FilterOperator.CONTAINS + # the value is bool if the filter operator is comparing with + # a boolean constant. + value: str | Sequence[str] | bool = "" class FilterBy(BaseModel): @@ -38,10 +45,10 @@ class FilterBy(BaseModel): conditions: Sequence[FilterCondition] = Field(default_factory=list) -class OrderBy(BaseModel): +class OrderByConfig(BaseModel): enabled: bool = False key: str = "" - value: Literal["asc", "desc"] = "asc" + value: Order = Order.ASC class Limit(BaseModel): @@ -57,6 +64,6 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy - order_by: OrderBy + order_by: OrderByConfig limit: Limit extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index d2e022dc9..a727a826c 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,18 +1,40 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from core.variables.segments import ArrayAnySegment, ArraySegment +from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.enums import ErrorStrategy, NodeType -from .entities import ListOperatorNodeData +from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError +_SUPPORTED_TYPES_TUPLE = ( + ArrayFileSegment, + ArrayNumberSegment, + ArrayStringSegment, + ArrayBooleanSegment, +) +_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment + + +_T = TypeVar("_T") + + +def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: + """Returns the negation of a given filter function. If the original filter + returns `True` for a value, the negated filter will return `False`, and vice versa. + """ + + def wrapper(value: _T) -> bool: + return not filter_(value) + + return wrapper + class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR @@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode): process_data=process_data, outputs=outputs, ) - if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): - error_message = ( - f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " - "or ArrayStringSegment" - ) + if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): + error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode): outputs=outputs, ) - def _apply_filter( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: filter_func: Callable[[Any], bool] result: list[Any] = [] for condition in self._node_data.filter_by.conditions: @@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode): ) result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayBooleanSegment): + if not isinstance(condition.value, bool): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + else: + raise AssertionError("this statment should be unreachable.") return variable - def _apply_order( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self._node_data.order_by.value, array=variable.value) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self._node_data.order_by.value, array=variable.value) + def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: + if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): + result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) + else: + raise AssertionError("this statement should be unreachable") + return variable - def _apply_slice( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) - def _extract_slice( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") @@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo case "empty": return lambda x: x == "" case "not contains": - return lambda x: not _contains(value)(x) + return _negation(_contains(value)) case "is not": - return lambda x: not _is(value)(x) + return _negation(_is(value)) case "not in": - return lambda x: not _in(value)(x) + return _negation(_in(value)) case "not empty": return lambda x: x != "" case _: @@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab case "in": return _in(value) case "not in": - return lambda x: not _in(value)(x) + return _negation(_in(value)) case _: raise InvalidConditionError(f"Invalid condition: {condition}") @@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ raise InvalidConditionError(f"Invalid condition: {condition}") +def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: + match condition: + case FilterOperator.IS: + return _is(value) + case FilterOperator.IS_NOT: + return _negation(_is(value)) + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + + def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): @@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str) -> Callable[[str], bool]: +def _is(value: _T) -> Callable[[_T], bool]: return lambda x: x == value @@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value -def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): - return sorted(array, key=lambda x: x, reverse=order == "desc") - - -def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): - return sorted(array, key=lambda x: x, reverse=order == "desc") - - -def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): +def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): extract_func: Callable[[File], Any] if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) elif order_by == "size": extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) else: raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index ecfbec703..10059fdcb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,7 +3,7 @@ import io import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.enums import ErrorStrategy, NodeType @@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + from core.workflow.graph_engine.entities.event import InNodeEvent logger = logging.getLogger(__name__) @@ -161,7 +161,7 @@ class LLMNode(BaseNode): def version(cls) -> str: return "1" - def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: node_inputs: Optional[dict[str, Any]] = None process_data = None result_text = "" diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index d04e0bfae..3ed4d21ba 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset( SegmentType.STRING, SegmentType.NUMBER, SegmentType.OBJECT, + SegmentType.BOOLEAN, SegmentType.ARRAY_STRING, SegmentType.ARRAY_NUMBER, SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_BOOLEAN, ] ) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index b2ab94312..3e52a3218 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -404,11 +404,11 @@ class LoopNode(BaseNode): for node_id in loop_graph.node_ids: variable_pool.remove([node_id]) - _outputs = {} + _outputs: dict[str, Segment | int | None] = {} for loop_variable_key, loop_variable_selector in loop_variable_selectors.items(): _loop_variable_segment = variable_pool.get(loop_variable_selector) if _loop_variable_segment: - _outputs[loop_variable_key] = _loop_variable_segment.value + _outputs[loop_variable_key] = _loop_variable_segment else: _outputs[loop_variable_key] = None @@ -522,21 +522,30 @@ class LoopNode(BaseNode): return variable_mapping @staticmethod - def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - if var_type in ["array[string]", "array[number]", "array[object]"]: - if value and isinstance(value, str): - value = json.loads(value) + if var_type in [ + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_STRING, + ]: + if original_value and isinstance(original_value, str): + value = json.loads(original_value) else: + logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) value = [] + elif var_type == SegmentType.ARRAY_BOOLEAN: + value = original_value + else: + raise AssertionError("this statement should be unreachable.") try: - return build_segment_with_type(var_type, value) + return build_segment_with_type(var_type, value=value) except TypeMismatchError as type_exc: # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(value, str): + if not isinstance(original_value, str): raise try: - value = json.loads(value) + value = json.loads(original_value) except ValueError: raise type_exc return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 916778d16..12347d21a 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,10 +1,46 @@ -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + field_validator, +) from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm import ModelConfig, VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig + +_OLD_BOOL_TYPE_NAME = "bool" +_OLD_SELECT_TYPE_NAME = "select" + +_VALID_PARAMETER_TYPES = frozenset( + [ + SegmentType.STRING, # "string", + SegmentType.NUMBER, # "number", + SegmentType.BOOLEAN, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_BOOLEAN, + _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node + _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. + ] +) + + +def _validate_type(parameter_type: str) -> SegmentType: + if not isinstance(parameter_type, str): + raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}") + if parameter_type not in _VALID_PARAMETER_TYPES: + raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") + + if parameter_type == _OLD_BOOL_TYPE_NAME: + return SegmentType.BOOLEAN + elif parameter_type == _OLD_SELECT_TYPE_NAME: + return SegmentType.STRING + return SegmentType(parameter_type) class _ParameterConfigError(Exception): @@ -17,7 +53,7 @@ class ParameterConfig(BaseModel): """ name: str - type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] + type: Annotated[SegmentType, BeforeValidator(_validate_type)] options: Optional[list[str]] = None description: str required: bool @@ -32,17 +68,20 @@ class ParameterConfig(BaseModel): return str(value) def is_array_type(self) -> bool: - return self.type in ("array[string]", "array[number]", "array[object]") + return self.type.is_array_type() - def element_type(self) -> Literal["string", "number", "object"]: - if self.type == "array[number]": - return "number" - elif self.type == "array[string]": - return "string" - elif self.type == "array[object]": - return "object" - else: - raise _ParameterConfigError(f"{self.type} is not array type.") + def element_type(self) -> SegmentType: + """Return the element type of the parameter. + + Raises a ValueError if the parameter's type is not an array type. + """ + element_type = self.type.element_type() + # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, + # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. + # + # See: _VALID_PARAMETER_TYPES for reference. + assert element_type is not None, f"the element type should not be None, {self.type=}" + return element_type class ParameterExtractorNodeData(BaseNodeData): @@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData): for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} - if parameter.type in {"string", "select"}: + if parameter.type == SegmentType.STRING: parameter_schema["type"] = "string" - elif parameter.type.startswith("array"): + elif parameter.type.is_array_type(): parameter_schema["type"] = "array" - nested_type = parameter.type[6:-1] - parameter_schema["items"] = {"type": nested_type} + element_type = parameter.type.element_type() + if element_type is None: + raise AssertionError("element type should not be None.") + parameter_schema["items"] = {"type": element_type.value} else: parameter_schema["type"] = parameter.type - if parameter.type == "select": + if parameter.options: parameter_schema["enum"] = parameter.options parameters["properties"][parameter.name] = parameter_schema diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py index 6511aba18..247518cf2 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -1,3 +1,8 @@ +from typing import Any + +from core.variables.types import SegmentType + + class ParameterExtractorNodeError(ValueError): """Base error for ParameterExtractorNode.""" @@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError): class InvalidModelModeError(ParameterExtractorNodeError): """Raised when the model mode is invalid.""" + + +class InvalidValueTypeError(ParameterExtractorNodeError): + def __init__( + self, + /, + parameter_name: str, + expected_type: SegmentType, + actual_type: SegmentType | None, + value: Any, + ) -> None: + message = ( + f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " + f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" + ) + super().__init__(message) + self.parameter_name = parameter_name + self.expected_type = expected_type + self.actual_type = actual_type + self.value = value diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 49c4c142e..3dcde5ad8 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.variables.types import SegmentType +from core.variables.types import ArrayValidation, SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( - InvalidArrayValueError, - InvalidBoolValueError, InvalidInvokeResultError, InvalidModelModeError, InvalidModelTypeError, InvalidNumberOfParametersError, - InvalidNumberValueError, InvalidSelectValueError, - InvalidStringValueError, InvalidTextContentTypeError, + InvalidValueTypeError, ModelSchemaNotFoundError, ParameterExtractorNodeError, RequiredParameterMissingError, @@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode): return prompt_messages def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: - """ - Validate result. - """ if len(data.parameters) != len(result): raise InvalidNumberOfParametersError("Invalid number of parameters") @@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode): if parameter.required and parameter.name not in result: raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - - if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") - - if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") - - if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") - - if parameter.type.startswith("array"): - parameters = result.get(parameter.name) - if not isinstance(parameters, list): - raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") - nested_type = parameter.type[6:-1] - for item in parameters: - if nested_type == "number" and not isinstance(item, int | float): - raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == "string" and not isinstance(item, str): - raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == "object" and not isinstance(item, dict): - raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + param_value = result.get(parameter.name) + if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): + inferred_type = SegmentType.infer_segment_type(param_value) + raise InvalidValueTypeError( + parameter_name=parameter.name, + expected_type=parameter.type, + actual_type=inferred_type, + value=param_value, + ) + if parameter.type == SegmentType.STRING and parameter.options: + if param_value not in parameter.options: + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") return result + @staticmethod + def _transform_number(value: int | float | str | bool) -> int | float | None: + """ + Attempts to transform the input into an integer or float. + + Returns: + int or float: The transformed number if the conversion is successful. + None: If the transformation fails. + + Note: + Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. + This behavior ensures compatibility with existing workflows that may use boolean types as integers. + """ + if isinstance(value, bool): + return int(value) + elif isinstance(value, (int, float)): + return value + elif not isinstance(value, str): + return None + if "." in value: + try: + return float(value) + except ValueError: + return None + else: + try: + return int(value) + except ValueError: + return None + def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: """ Transform result into standard format. """ - transformed_result = {} + transformed_result: dict[str, Any] = {} for parameter in data.parameters: if parameter.name in result: + param_value = result[parameter.name] # transform value - if parameter.type == "number": - if isinstance(result[parameter.name], int | float): - transformed_result[parameter.name] = result[parameter.name] - elif isinstance(result[parameter.name], str): - try: - if "." in result[parameter.name]: - result[parameter.name] = float(result[parameter.name]) - else: - result[parameter.name] = int(result[parameter.name]) - except ValueError: - pass - else: - pass - # TODO: bool is not supported in the current version - # elif parameter.type == 'bool': - # if isinstance(result[parameter.name], bool): - # transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ['true', 'false']: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') - # elif isinstance(result[parameter.name], int): - # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in {"string", "select"}: - if isinstance(result[parameter.name], str): - transformed_result[parameter.name] = result[parameter.name] + if parameter.type == SegmentType.NUMBER: + transformed = self._transform_number(param_value) + if transformed is not None: + transformed_result[parameter.name] = transformed + elif parameter.type == SegmentType.BOOLEAN: + if isinstance(result[parameter.name], (bool, int)): + transformed_result[parameter.name] = bool(result[parameter.name]) + # elif isinstance(result[parameter.name], str): + # if result[parameter.name].lower() in ["true", "false"]: + # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") + elif parameter.type == SegmentType.STRING: + if isinstance(param_value, str): + transformed_result[parameter.name] = param_value elif parameter.is_array_type(): - if isinstance(result[parameter.name], list): + if isinstance(param_value, list): nested_type = parameter.element_type() assert nested_type is not None segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) transformed_result[parameter.name] = segment_value - for item in result[parameter.name]: - if nested_type == "number": - if isinstance(item, int | float): - segment_value.value.append(item) - elif isinstance(item, str): - try: - if "." in item: - segment_value.value.append(float(item)) - else: - segment_value.value.append(int(item)) - except ValueError: - pass - elif nested_type == "string": + for item in param_value: + if nested_type == SegmentType.NUMBER: + transformed = self._transform_number(item) + if transformed is not None: + segment_value.value.append(transformed) + elif nested_type == SegmentType.STRING: if isinstance(item, str): segment_value.value.append(item) - elif nested_type == "object": + elif nested_type == SegmentType.OBJECT: if isinstance(item, dict): segment_value.value.append(item) + elif nested_type == SegmentType.BOOLEAN: + if isinstance(item, bool): + segment_value.value.append(item) if parameter.name not in transformed_result: - if parameter.type == "number": - transformed_result[parameter.name] = 0 - elif parameter.type == "bool": - transformed_result[parameter.name] = False - elif parameter.type in {"string", "select"}: - transformed_result[parameter.name] = "" - elif parameter.type.startswith("array"): + if parameter.type.is_array_type(): transformed_result[parameter.name] = build_segment_with_type( segment_type=SegmentType(parameter.type), value=[] ) + elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): + transformed_result[parameter.name] = "" + elif parameter.type == SegmentType.NUMBER: + transformed_result[parameter.name] = 0 + elif parameter.type == SegmentType.BOOLEAN: + transformed_result[parameter.name] = False + else: + raise AssertionError("this statement should be unreachable.") return transformed_result diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 51383fa58..321d280b1 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, TypeAlias from core.variables import SegmentType, Variable +from core.variables.segments import BooleanSegment from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult @@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode): def get_zero_value(t: SegmentType): # TODO(QuantumGhost): this should be a method of `SegmentType`. match t: - case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return variable_factory.build_segment([]) + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN: + return variable_factory.build_segment_with_type(t, []) case SegmentType.OBJECT: return variable_factory.build_segment({}) case SegmentType.STRING: @@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) + case SegmentType.BOOLEAN: + return BooleanSegment(value=False) case _: raise VariableOperatorNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 7f760e5ba..1a4b81c39 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -4,9 +4,11 @@ from core.variables import SegmentType EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, + SegmentType.BOOLEAN: False, SegmentType.OBJECT: {}, SegmentType.ARRAY_ANY: [], SegmentType.ARRAY_STRING: [], SegmentType.ARRAY_NUMBER: [], SegmentType.ARRAY_OBJECT: [], + SegmentType.ARRAY_BOOLEAN: [], } diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 7a20975b1..324f23a90 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT, + SegmentType.BOOLEAN, } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND: + case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: # Only array variable can be appended or extended - return variable_type in { - SegmentType.ARRAY_ANY, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_FILE, - } - case Operation.REMOVE_FIRST | Operation.REMOVE_LAST: # Only array variable can have elements removed - return variable_type in { - SegmentType.ARRAY_ANY, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_FILE, - } + return variable_type.is_array_type() case _: return False @@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation): def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): match variable_type: - case SegmentType.STRING | SegmentType.OBJECT: + case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: return operation in {Operation.OVER_WRITE, Operation.SET} case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { @@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) + case SegmentType.BOOLEAN: + return isinstance(value, bool) + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False @@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, int | float) case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: return isinstance(value, dict) + case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: + return isinstance(value, bool) # Array & Extend / Overwrite case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: @@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va return isinstance(value, list) and all(isinstance(item, int | float) for item in value) case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: return isinstance(value, list) and all(isinstance(item, dict) for item in value) + case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: + return isinstance(value, list) and all(isinstance(item, bool) for item in value) case _: return False diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index 56871a15d..77a214571 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel): class Condition(BaseModel): variable_selector: list[str] comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None + value: str | Sequence[str] | bool | None = None sub_variable_condition: SubVariableCondition | None = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 979538778..7efd1acbf 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,13 +1,27 @@ +import json from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Literal, Union from core.file import FileAttribute, file_manager from core.variables import ArrayFileSegment +from core.variables.segments import ArrayBooleanSegment, BooleanSegment from core.workflow.entities.variable_pool import VariablePool from .entities import Condition, SubCondition, SupportedComparisonOperator +def _convert_to_bool(value: Any) -> bool: + if isinstance(value, int): + return bool(value) + + if isinstance(value, str): + loaded = json.loads(value) + if isinstance(loaded, (int, bool)): + return bool(loaded) + + raise TypeError(f"unexpected value: type={type(value)}, value={value}") + + class ConditionProcessor: def process_conditions( self, @@ -48,9 +62,16 @@ class ConditionProcessor: ) else: actual_value = variable.value if variable else None - expected_value = condition.value + expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value if isinstance(expected_value, str): expected_value = variable_pool.convert_template(expected_value).text + # Here we need to explicit convet the input string to boolean. + if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None: + # The following two lines is for compatibility with existing workflows. + if isinstance(expected_value, list): + expected_value = [_convert_to_bool(i) for i in expected_value] + else: + expected_value = _convert_to_bool(expected_value) input_conditions.append( { "actual_value": actual_value, @@ -77,7 +98,7 @@ def _evaluate_condition( *, operator: SupportedComparisonOperator, value: Any, - expected: str | Sequence[str] | None, + expected: Union[str, Sequence[str], bool | Sequence[bool], None], ) -> bool: match operator: case "contains": @@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool: if not value: return False - if not isinstance(value, str | list): + if not isinstance(value, (str, list)): raise ValueError("Invalid actual value type: string or array") if expected not in value: @@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool: if not value: return True - if not isinstance(value, str | list): + if not isinstance(value, (str, list)): raise ValueError("Invalid actual value type: string or array") if expected in value: @@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") + if not isinstance(value, (str, bool)): + raise ValueError("Invalid actual value type: string or boolean") if value != expected: return False @@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") + if not isinstance(value, (str, bool)): + raise ValueError("Invalid actual value type: string or boolean") if value == expected: return False @@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): - raise ValueError("Invalid actual value type: number") + if not isinstance(value, (int, float, bool)): + raise ValueError("Invalid actual value type: number or boolean") - if isinstance(value, int): + # Handle boolean comparison + if isinstance(value, bool): + expected = bool(expected) + elif isinstance(value, int): expected = int(expected) else: expected = float(expected) @@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): - raise ValueError("Invalid actual value type: number") + if not isinstance(value, (int, float, bool)): + raise ValueError("Invalid actual value type: number or boolean") - if isinstance(value, int): + # Handle boolean comparison + if isinstance(value, bool): + expected = bool(expected) + elif isinstance(value, int): expected = int(expected) else: expected = float(expected) @@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): + if not isinstance(value, (int, float)): raise ValueError("Invalid actual value type: number") if isinstance(value, int): @@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): + if not isinstance(value, (int, float)): raise ValueError("Invalid actual value type: number") if isinstance(value, int): @@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): + if not isinstance(value, (int, float)): raise ValueError("Invalid actual value type: number") if isinstance(value, int): @@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: if value is None: return False - if not isinstance(value, int | float): + if not isinstance(value, (int, float)): raise ValueError("Invalid actual value type: number") if isinstance(value, int): diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 39ebd009d..aa9828f3d 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -7,11 +7,13 @@ from core.file import File from core.variables.exc import VariableError from core.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArraySegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, @@ -23,10 +25,12 @@ from core.variables.segments import ( from core.variables.types import SegmentType from core.variables.variables import ( ArrayAnyVariable, + ArrayBooleanVariable, ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, + BooleanVariable, FileVariable, FloatVariable, IntegerVariable, @@ -49,17 +53,19 @@ class TypeMismatchError(Exception): # Define the constant SEGMENT_TO_VARIABLE_MAP = { - StringSegment: StringVariable, - IntegerSegment: IntegerVariable, - FloatSegment: FloatVariable, - ObjectSegment: ObjectVariable, - FileSegment: FileVariable, - ArrayStringSegment: ArrayStringVariable, + ArrayAnySegment: ArrayAnyVariable, + ArrayBooleanSegment: ArrayBooleanVariable, + ArrayFileSegment: ArrayFileVariable, ArrayNumberSegment: ArrayNumberVariable, ArrayObjectSegment: ArrayObjectVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayAnySegment: ArrayAnyVariable, + ArrayStringSegment: ArrayStringVariable, + BooleanSegment: BooleanVariable, + FileSegment: FileVariable, + FloatSegment: FloatVariable, + IntegerSegment: IntegerVariable, NoneSegment: NoneVariable, + ObjectSegment: ObjectVariable, + StringSegment: StringVariable, } @@ -99,6 +105,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen mapping = dict(mapping) mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) + case SegmentType.BOOLEAN: + result = BooleanVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): @@ -109,6 +117,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) + case SegmentType.ARRAY_BOOLEAN if isinstance(value, list): + result = ArrayBooleanVariable.model_validate(mapping) case _: raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: @@ -129,6 +139,8 @@ def build_segment(value: Any, /) -> Segment: return NoneSegment() if isinstance(value, str): return StringSegment(value=value) + if isinstance(value, bool): + return BooleanSegment(value=value) if isinstance(value, int): return IntegerSegment(value=value) if isinstance(value, float): @@ -152,6 +164,8 @@ def build_segment(value: Any, /) -> Segment: return ArrayStringSegment(value=value) case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return ArrayNumberSegment(value=value) + case SegmentType.BOOLEAN: + return ArrayBooleanSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) case SegmentType.FILE: @@ -170,6 +184,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = { SegmentType.INTEGER: IntegerSegment, SegmentType.FLOAT: FloatSegment, SegmentType.FILE: FileSegment, + SegmentType.BOOLEAN: BooleanSegment, SegmentType.OBJECT: ObjectSegment, # Array types SegmentType.ARRAY_ANY: ArrayAnySegment, @@ -177,6 +192,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = { SegmentType.ARRAY_NUMBER: ArrayNumberSegment, SegmentType.ARRAY_OBJECT: ArrayObjectSegment, SegmentType.ARRAY_FILE: ArrayFileSegment, + SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, } @@ -225,6 +241,8 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: return ArrayAnySegment(value=value) elif segment_type == SegmentType.ARRAY_STRING: return ArrayStringSegment(value=value) + elif segment_type == SegmentType.ARRAY_BOOLEAN: + return ArrayBooleanSegment(value=value) elif segment_type == SegmentType.ARRAY_NUMBER: return ArrayNumberSegment(value=value) elif segment_type == SegmentType.ARRAY_OBJECT: diff --git a/api/lazy_load_class.py b/api/lazy_load_class.py new file mode 100644 index 000000000..dd3c2a16e --- /dev/null +++ b/api/lazy_load_class.py @@ -0,0 +1,11 @@ +from tests.integration_tests.utils.parent_class import ParentClass + + +class LazyLoadChildClass(ParentClass): + """Test lazy load child class for module import helper tests""" + + def __init__(self, name): + super().__init__(name) + + def get_name(self): + return self.name diff --git a/api/mypy.ini b/api/mypy.ini index 44a01068e..bd771a056 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -20,3 +20,6 @@ ignore_missing_imports=True [mypy-flask_restx.inputs] ignore_missing_imports=True + +[mypy-google.cloud.storage] +ignore_missing_imports=True diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index b33a83ba7..a197b617f 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -23,6 +23,7 @@ class TestSegmentTypeIsArrayType: SegmentType.ARRAY_NUMBER, SegmentType.ARRAY_OBJECT, SegmentType.ARRAY_FILE, + SegmentType.ARRAY_BOOLEAN, ] expected_non_array_types = [ SegmentType.INTEGER, @@ -34,6 +35,7 @@ class TestSegmentTypeIsArrayType: SegmentType.FILE, SegmentType.NONE, SegmentType.GROUP, + SegmentType.BOOLEAN, ] for seg_type in expected_array_types: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py new file mode 100644 index 000000000..e0541280d --- /dev/null +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -0,0 +1,729 @@ +""" +Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods. + +This module provides thorough testing of the validation logic for all SegmentType values, +including edge cases, error conditions, and different ArrayValidation strategies. +""" + +from dataclasses import dataclass +from typing import Any + +import pytest + +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File +from core.variables.types import ArrayValidation, SegmentType + + +def create_test_file( + file_type: FileType = FileType.DOCUMENT, + transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + filename: str = "test.txt", + extension: str = ".txt", + mime_type: str = "text/plain", + size: int = 1024, +) -> File: + """Factory function to create File objects for testing.""" + return File( + tenant_id="test-tenant", + type=file_type, + transfer_method=transfer_method, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None, + remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None, + storage_key="test-storage-key", + ) + + +@dataclass +class ValidationTestCase: + """Test case data structure for validation tests.""" + + segment_type: SegmentType + value: Any + expected: bool + description: str + + def get_id(self): + return self.description + + +@dataclass +class ArrayValidationTestCase: + """Test case data structure for array validation tests.""" + + segment_type: SegmentType + value: Any + array_validation: ArrayValidation + expected: bool + description: str + + def get_id(self): + return self.description + + +# Test data construction functions +def get_boolean_cases() -> list[ValidationTestCase]: + return [ + # valid values + ValidationTestCase(SegmentType.BOOLEAN, True, True, "True boolean"), + ValidationTestCase(SegmentType.BOOLEAN, False, True, "False boolean"), + # Invalid values + ValidationTestCase(SegmentType.BOOLEAN, 1, False, "Integer 1 (not boolean)"), + ValidationTestCase(SegmentType.BOOLEAN, 0, False, "Integer 0 (not boolean)"), + ValidationTestCase(SegmentType.BOOLEAN, "true", False, "String 'true'"), + ValidationTestCase(SegmentType.BOOLEAN, "false", False, "String 'false'"), + ValidationTestCase(SegmentType.BOOLEAN, None, False, "None value"), + ValidationTestCase(SegmentType.BOOLEAN, [], False, "Empty list"), + ValidationTestCase(SegmentType.BOOLEAN, {}, False, "Empty dict"), + ] + + +def get_number_cases() -> list[ValidationTestCase]: + """Get test cases for valid number values.""" + return [ + # valid values + ValidationTestCase(SegmentType.NUMBER, 42, True, "Positive integer"), + ValidationTestCase(SegmentType.NUMBER, -42, True, "Negative integer"), + ValidationTestCase(SegmentType.NUMBER, 0, True, "Zero integer"), + ValidationTestCase(SegmentType.NUMBER, 3.14, True, "Positive float"), + ValidationTestCase(SegmentType.NUMBER, -3.14, True, "Negative float"), + ValidationTestCase(SegmentType.NUMBER, 0.0, True, "Zero float"), + ValidationTestCase(SegmentType.NUMBER, float("inf"), True, "Positive infinity"), + ValidationTestCase(SegmentType.NUMBER, float("-inf"), True, "Negative infinity"), + ValidationTestCase(SegmentType.NUMBER, float("nan"), True, "float(NaN)"), + # invalid number values + ValidationTestCase(SegmentType.NUMBER, "42", False, "String number"), + ValidationTestCase(SegmentType.NUMBER, None, False, "None value"), + ValidationTestCase(SegmentType.NUMBER, [], False, "Empty list"), + ValidationTestCase(SegmentType.NUMBER, {}, False, "Empty dict"), + ValidationTestCase(SegmentType.NUMBER, "3.14", False, "String float"), + ] + + +def get_string_cases() -> list[ValidationTestCase]: + """Get test cases for valid string values.""" + return [ + # valid values + ValidationTestCase(SegmentType.STRING, "", True, "Empty string"), + ValidationTestCase(SegmentType.STRING, "hello", True, "Simple string"), + ValidationTestCase(SegmentType.STRING, "🚀", True, "Unicode emoji"), + ValidationTestCase(SegmentType.STRING, "line1\nline2", True, "Multiline string"), + # invalid values + ValidationTestCase(SegmentType.STRING, 123, False, "Integer"), + ValidationTestCase(SegmentType.STRING, 3.14, False, "Float"), + ValidationTestCase(SegmentType.STRING, True, False, "Boolean"), + ValidationTestCase(SegmentType.STRING, None, False, "None value"), + ValidationTestCase(SegmentType.STRING, [], False, "Empty list"), + ValidationTestCase(SegmentType.STRING, {}, False, "Empty dict"), + ] + + +def get_object_cases() -> list[ValidationTestCase]: + """Get test cases for valid object values.""" + return [ + # valid cases + ValidationTestCase(SegmentType.OBJECT, {}, True, "Empty dict"), + ValidationTestCase(SegmentType.OBJECT, {"key": "value"}, True, "Simple dict"), + ValidationTestCase(SegmentType.OBJECT, {"a": 1, "b": 2}, True, "Dict with numbers"), + ValidationTestCase(SegmentType.OBJECT, {"nested": {"key": "value"}}, True, "Nested dict"), + ValidationTestCase(SegmentType.OBJECT, {"list": [1, 2, 3]}, True, "Dict with list"), + ValidationTestCase(SegmentType.OBJECT, {"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"), + # invalid cases + ValidationTestCase(SegmentType.OBJECT, "not a dict", False, "String"), + ValidationTestCase(SegmentType.OBJECT, 123, False, "Integer"), + ValidationTestCase(SegmentType.OBJECT, 3.14, False, "Float"), + ValidationTestCase(SegmentType.OBJECT, True, False, "Boolean"), + ValidationTestCase(SegmentType.OBJECT, None, False, "None value"), + ValidationTestCase(SegmentType.OBJECT, [], False, "Empty list"), + ValidationTestCase(SegmentType.OBJECT, [1, 2, 3], False, "List with values"), + ] + + +def get_secret_cases() -> list[ValidationTestCase]: + """Get test cases for valid secret values.""" + return [ + # valid cases + ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"), + ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"), + ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"), + ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"), + # invalid cases + ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"), + ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"), + ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"), + ValidationTestCase(SegmentType.SECRET, None, False, "None value"), + ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"), + ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"), + ] + + +def get_file_cases() -> list[ValidationTestCase]: + """Get test cases for valid file values.""" + test_file = create_test_file() + image_file = create_test_file( + file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg" + ) + remote_file = create_test_file( + transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf" + ) + + return [ + # valid cases + ValidationTestCase(SegmentType.FILE, test_file, True, "Document file"), + ValidationTestCase(SegmentType.FILE, image_file, True, "Image file"), + ValidationTestCase(SegmentType.FILE, remote_file, True, "Remote file"), + # invalid cases + ValidationTestCase(SegmentType.FILE, "not a file", False, "String"), + ValidationTestCase(SegmentType.FILE, 123, False, "Integer"), + ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"), + ValidationTestCase(SegmentType.FILE, None, False, "None value"), + ValidationTestCase(SegmentType.FILE, [], False, "Empty list"), + ValidationTestCase(SegmentType.FILE, True, False, "Boolean"), + ] + + +def get_none_cases() -> list[ValidationTestCase]: + """Get test cases for valid none values.""" + return [ + # valid cases + ValidationTestCase(SegmentType.NONE, None, True, "None value"), + # invalid cases + ValidationTestCase(SegmentType.NONE, "", False, "Empty string"), + ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"), + ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"), + ValidationTestCase(SegmentType.NONE, False, False, "False boolean"), + ValidationTestCase(SegmentType.NONE, [], False, "Empty list"), + ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"), + ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"), + ] + + +def get_array_any_validation_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_ANY validation.""" + return [ + ArrayValidationTestCase( + SegmentType.ARRAY_ANY, + [1, "string", 3.14, {"key": "value"}, True], + ArrayValidation.NONE, + True, + "Mixed types with NONE validation", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_ANY, + [1, "string", 3.14, {"key": "value"}, True], + ArrayValidation.FIRST, + True, + "Mixed types with FIRST validation", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_ANY, + [1, "string", 3.14, {"key": "value"}, True], + ArrayValidation.ALL, + True, + "Mixed types with ALL validation", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_ANY, [None, None, None], ArrayValidation.ALL, True, "All None values" + ), + ] + + +def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_STRING validation with NONE strategy.""" + return [ + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, + ["hello", "world"], + ArrayValidation.NONE, + True, + "Valid strings with NONE validation", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, + [123, 456], + ArrayValidation.NONE, + True, + "Invalid elements with NONE validation", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, + ["valid", 123, True], + ArrayValidation.NONE, + True, + "Mixed types with NONE validation", + ), + ] + + +def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_STRING validation with FIRST strategy.""" + return [ + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, ["hello", "world"], ArrayValidation.FIRST, True, "All valid strings" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, + ["hello", 123, True], + ArrayValidation.FIRST, + True, + "First valid, others invalid", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, + [123, "hello", "world"], + ArrayValidation.FIRST, + False, + "First invalid, others valid", + ), + ArrayValidationTestCase(SegmentType.ARRAY_STRING, [None, "hello"], ArrayValidation.FIRST, False, "First None"), + ] + + +def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_STRING validation with ALL strategy.""" + return [ + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, ["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, ["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, [123, 456, 789], ArrayValidation.ALL, False, "All invalid elements" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_STRING, ["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None" + ), + ] + + +def get_array_number_validation_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_NUMBER validation with different strategies.""" + return [ + # NONE strategy + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE" + ), + # FIRST strategy + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats" + ), + # ALL strategy + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_NUMBER, + [float("inf"), float("-inf"), float("nan")], + ArrayValidation.ALL, + True, + "Special float values", + ), + ] + + +def get_array_object_validation_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_OBJECT validation with different strategies.""" + return [ + # NONE strategy + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE" + ), + # FIRST strategy + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, + [{"valid": "object"}, "not an object"], + ArrayValidation.FIRST, + True, + "First valid object", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, + ["not an object", {"valid": "object"}], + ArrayValidation.FIRST, + False, + "First invalid", + ), + # ALL strategy + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, + [{}, {"a": 1}, {"nested": {"key": "value"}}], + ArrayValidation.ALL, + True, + "All valid objects", + ), + ArrayValidationTestCase( + SegmentType.ARRAY_OBJECT, + [{"valid": "object"}, "invalid", {"another": "object"}], + ArrayValidation.ALL, + False, + "One invalid element", + ), + ] + + +def get_array_file_validation_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_FILE validation with different strategies.""" + file1 = create_test_file(filename="file1.txt") + file2 = create_test_file(filename="file2.txt") + + return [ + # NONE strategy + ArrayValidationTestCase( + SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.NONE, True, "Valid files with NONE" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_FILE, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE" + ), + # FIRST strategy + ArrayValidationTestCase( + SegmentType.ARRAY_FILE, [file1, "not a file"], ArrayValidation.FIRST, True, "First valid file" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_FILE, ["not a file", file1], ArrayValidation.FIRST, False, "First invalid" + ), + # ALL strategy + ArrayValidationTestCase(SegmentType.ARRAY_FILE, [file1, file2], ArrayValidation.ALL, True, "All valid files"), + ArrayValidationTestCase( + SegmentType.ARRAY_FILE, [file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element" + ), + ] + + +def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]: + """Get test cases for ARRAY_BOOLEAN validation with different strategies.""" + return [ + # NONE strategy + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE" + ), + # FIRST strategy + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)" + ), + # ALL strategy + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)" + ), + ArrayValidationTestCase( + SegmentType.ARRAY_BOOLEAN, + [True, "false", False], + ArrayValidation.ALL, + False, + "One invalid element (string)", + ), + ] + + +class TestSegmentTypeIsValid: + """Test suite for SegmentType.is_valid method covering all non-array types.""" + + @pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description) + def test_boolean_validation(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description) + def test_number_validation(self, case: ValidationTestCase): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description) + def test_string_validation(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description) + def test_object_validation(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description) + def test_secret_validation(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description) + def test_file_validation(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + @pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description) + def test_none_validation_valid_cases(self, case): + assert case.segment_type.is_valid(case.value) == case.expected + + def test_unsupported_segment_type_raises_assertion_error(self): + """Test that unsupported SegmentType values raise AssertionError.""" + # GROUP is not handled in is_valid method + with pytest.raises(AssertionError, match="this statement should be unreachable"): + SegmentType.GROUP.is_valid("any value") + + +class TestSegmentTypeArrayValidation: + """Test suite for SegmentType._validate_array method and array type validation.""" + + def test_array_validation_non_list_values(self): + """Test that non-list values return False for all array types.""" + array_types = [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + SegmentType.ARRAY_BOOLEAN, + ] + + non_list_values = [ + "not a list", + 123, + 3.14, + True, + None, + {"key": "value"}, + create_test_file(), + ] + + for array_type in array_types: + for value in non_list_values: + assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}" + + def test_empty_array_validation(self): + """Test that empty arrays are valid for all array types regardless of validation strategy.""" + array_types = [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + SegmentType.ARRAY_BOOLEAN, + ] + + validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL] + + for array_type in array_types: + for strategy in validation_strategies: + assert array_type.is_valid([], strategy) is True, ( + f"{array_type} should accept empty array with {strategy}" + ) + + @pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description) + def test_array_any_validation(self, case): + """Test ARRAY_ANY validation accepts any list regardless of content.""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description) + def test_array_string_validation_with_none_strategy(self, case): + """Test ARRAY_STRING validation with NONE strategy (no element validation).""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description) + def test_array_string_validation_with_first_strategy(self, case): + """Test ARRAY_STRING validation with FIRST strategy (validate first element only).""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description) + def test_array_string_validation_with_all_strategy(self, case): + """Test ARRAY_STRING validation with ALL strategy (validate all elements).""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description) + def test_array_number_validation_with_different_strategies(self, case): + """Test ARRAY_NUMBER validation with different validation strategies.""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description) + def test_array_object_validation_with_different_strategies(self, case): + """Test ARRAY_OBJECT validation with different validation strategies.""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description) + def test_array_file_validation_with_different_strategies(self, case): + """Test ARRAY_FILE validation with different validation strategies.""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + @pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description) + def test_array_boolean_validation_with_different_strategies(self, case): + """Test ARRAY_BOOLEAN validation with different validation strategies.""" + assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected + + def test_default_array_validation_strategy(self): + """Test that default array validation strategy is FIRST.""" + # When no array_validation parameter is provided, it should default to FIRST + assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False # First element valid + assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False # First element invalid + + assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False # First element valid + assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False # First element invalid + + def test_array_validation_edge_cases(self): + """Test edge cases for array validation.""" + # Test with nested arrays (should be invalid for specific array types) + nested_array = [["nested", "array"], ["another", "nested"]] + + assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False + assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False + assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True + + # Test with very large arrays (performance consideration) + large_valid_array = ["string"] * 1000 + large_mixed_array = ["string"] * 999 + [123] # Last element invalid + + assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True + assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False + assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True + + +class TestSegmentTypeValidationIntegration: + """Integration tests for SegmentType validation covering interactions between methods.""" + + def test_non_array_types_ignore_array_validation_parameter(self): + """Test that non-array types ignore the array_validation parameter.""" + non_array_types = [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + SegmentType.OBJECT, + SegmentType.SECRET, + SegmentType.FILE, + SegmentType.NONE, + ] + + for segment_type in non_array_types: + # Create appropriate valid value for each type + valid_value: Any + if segment_type == SegmentType.STRING: + valid_value = "test" + elif segment_type == SegmentType.NUMBER: + valid_value = 42 + elif segment_type == SegmentType.BOOLEAN: + valid_value = True + elif segment_type == SegmentType.OBJECT: + valid_value = {"key": "value"} + elif segment_type == SegmentType.SECRET: + valid_value = "secret" + elif segment_type == SegmentType.FILE: + valid_value = create_test_file() + elif segment_type == SegmentType.NONE: + valid_value = None + else: + continue # Skip unsupported types + + # All array validation strategies should give the same result + result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE) + result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST) + result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL) + + assert result_none == result_first == result_all == True, ( + f"{segment_type} should ignore array_validation parameter" + ) + + def test_comprehensive_type_coverage(self): + """Test that all SegmentType enum values are covered in validation tests.""" + all_segment_types = set(SegmentType) + + # Types that should be handled by is_valid method + handled_types = { + # Non-array types + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + SegmentType.OBJECT, + SegmentType.SECRET, + SegmentType.FILE, + SegmentType.NONE, + # Array types + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + SegmentType.ARRAY_BOOLEAN, + } + + # Types that are not handled by is_valid (should raise AssertionError) + unhandled_types = { + SegmentType.GROUP, + SegmentType.INTEGER, # Handled by NUMBER validation logic + SegmentType.FLOAT, # Handled by NUMBER validation logic + } + + # Verify all types are accounted for + assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized" + + # Test that handled types work correctly + for segment_type in handled_types: + if segment_type.is_array_type(): + # Test with empty array (should always be valid) + assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array" + else: + # Test with appropriate valid value + if segment_type == SegmentType.STRING: + assert segment_type.is_valid("test") is True + elif segment_type == SegmentType.NUMBER: + assert segment_type.is_valid(42) is True + elif segment_type == SegmentType.BOOLEAN: + assert segment_type.is_valid(True) is True + elif segment_type == SegmentType.OBJECT: + assert segment_type.is_valid({}) is True + elif segment_type == SegmentType.SECRET: + assert segment_type.is_valid("secret") is True + elif segment_type == SegmentType.FILE: + assert segment_type.is_valid(create_test_file()) is True + elif segment_type == SegmentType.NONE: + assert segment_type.is_valid(None) is True + + def test_boolean_vs_integer_type_distinction(self): + """Test the important distinction between boolean and integer types in validation.""" + # This tests the comment in the code about bool being a subclass of int + + # Boolean type should only accept actual booleans, not integers + assert SegmentType.BOOLEAN.is_valid(True) is True + assert SegmentType.BOOLEAN.is_valid(False) is True + assert SegmentType.BOOLEAN.is_valid(1) is False # Integer 1, not boolean + assert SegmentType.BOOLEAN.is_valid(0) is False # Integer 0, not boolean + + # Number type should accept both integers and floats, including booleans (since bool is subclass of int) + assert SegmentType.NUMBER.is_valid(42) is True + assert SegmentType.NUMBER.is_valid(3.14) is True + assert SegmentType.NUMBER.is_valid(True) is True # bool is subclass of int + assert SegmentType.NUMBER.is_valid(False) is True # bool is subclass of int + + def test_array_validation_recursive_behavior(self): + """Test that array validation correctly handles recursive validation calls.""" + # When validating array elements, _validate_array calls is_valid recursively + # with ArrayValidation.NONE to avoid infinite recursion + + # Test nested validation doesn't cause issues + nested_arrays = [["inner", "array"], ["another", "inner"]] + + # ARRAY_ANY should accept nested arrays + assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True + + # ARRAY_STRING should reject nested arrays (first element is not a string) + assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False + assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/__init__.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py new file mode 100644 index 000000000..b28d1d3d0 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -0,0 +1,27 @@ +from core.variables.types import SegmentType +from core.workflow.nodes.parameter_extractor.entities import ParameterConfig + + +class TestParameterConfig: + def test_select_type(self): + data = { + "name": "yes_or_no", + "type": "select", + "options": ["yes", "no"], + "description": "a simple select made of `yes` and `no`", + "required": True, + } + + pc = ParameterConfig.model_validate(data) + assert pc.type == SegmentType.STRING + assert pc.options == data["options"] + + def test_validate_bool_type(self): + data = { + "name": "boolean", + "type": "bool", + "description": "a simple boolean parameter", + "required": True, + } + pc = ParameterConfig.model_validate(data) + assert pc.type == SegmentType.BOOLEAN diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py new file mode 100644 index 000000000..b9947d469 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -0,0 +1,567 @@ +""" +Test cases for ParameterExtractorNode._validate_result and _transform_result methods. +""" + +from dataclasses import dataclass +from typing import Any + +import pytest + +from core.model_runtime.entities import LLMMode +from core.variables.types import SegmentType +from core.workflow.nodes.llm import ModelConfig, VisionConfig +from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from core.workflow.nodes.parameter_extractor.exc import ( + InvalidNumberOfParametersError, + InvalidSelectValueError, + InvalidValueTypeError, + RequiredParameterMissingError, +) +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from factories.variable_factory import build_segment_with_type + + +@dataclass +class ValidTestCase: + """Test case data for valid scenarios.""" + + name: str + parameters: list[ParameterConfig] + result: dict[str, Any] + + def get_name(self) -> str: + return self.name + + +@dataclass +class ErrorTestCase: + """Test case data for error scenarios.""" + + name: str + parameters: list[ParameterConfig] + result: dict[str, Any] + expected_exception: type[Exception] + expected_message: str + + def get_name(self) -> str: + return self.name + + +@dataclass +class TransformTestCase: + """Test case data for transformation scenarios.""" + + name: str + parameters: list[ParameterConfig] + input_result: dict[str, Any] + expected_result: dict[str, Any] + + def get_name(self) -> str: + return self.name + + +class TestParameterExtractorNodeMethods: + """Test helper class that provides access to the methods under test.""" + + def validate_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]: + """Wrapper to call _validate_result method.""" + node = ParameterExtractorNode.__new__(ParameterExtractorNode) + return node._validate_result(data=data, result=result) + + def transform_result(self, data: ParameterExtractorNodeData, result: dict[str, Any]) -> dict[str, Any]: + """Wrapper to call _transform_result method.""" + node = ParameterExtractorNode.__new__(ParameterExtractorNode) + return node._transform_result(data=data, result=result) + + +class TestValidateResult: + """Test cases for _validate_result method.""" + + @staticmethod + def get_valid_test_cases() -> list[ValidTestCase]: + """Get test cases that should pass validation.""" + return [ + ValidTestCase( + name="single_string_parameter", + parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)], + result={"name": "John"}, + ), + ValidTestCase( + name="single_number_parameter_int", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + result={"age": 25}, + ), + ValidTestCase( + name="single_number_parameter_float", + parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)], + result={"price": 19.99}, + ), + ValidTestCase( + name="single_bool_parameter_true", + parameters=[ + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True) + ], + result={"active": True}, + ), + ValidTestCase( + name="single_bool_parameter_true", + parameters=[ + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True) + ], + result={"active": True}, + ), + ValidTestCase( + name="single_bool_parameter_false", + parameters=[ + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True) + ], + result={"active": False}, + ), + ValidTestCase( + name="select_parameter_valid_option", + parameters=[ + ParameterConfig( + name="status", + type="select", # pyright: ignore[reportArgumentType] + description="Status", + required=True, + options=["active", "inactive"], + ) + ], + result={"status": "active"}, + ), + ValidTestCase( + name="array_string_parameter", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + result={"tags": ["tag1", "tag2", "tag3"]}, + ), + ValidTestCase( + name="array_number_parameter", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + result={"scores": [85, 92.5, 78]}, + ), + ValidTestCase( + name="array_object_parameter", + parameters=[ + ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True) + ], + result={"items": [{"name": "item1"}, {"name": "item2"}]}, + ), + ValidTestCase( + name="multiple_parameters", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True), + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True), + ], + result={"name": "John", "age": 25, "active": True}, + ), + ValidTestCase( + name="optional_parameter_present", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ParameterConfig(name="nickname", type=SegmentType.STRING, description="Nickname", required=False), + ], + result={"name": "John", "nickname": "Johnny"}, + ), + ValidTestCase( + name="empty_array_parameter", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + result={"tags": []}, + ), + ] + + @staticmethod + def get_error_test_cases() -> list[ErrorTestCase]: + """Get test cases that should raise exceptions.""" + return [ + ErrorTestCase( + name="invalid_number_of_parameters_too_few", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True), + ], + result={"name": "John"}, + expected_exception=InvalidNumberOfParametersError, + expected_message="Invalid number of parameters", + ), + ErrorTestCase( + name="invalid_number_of_parameters_too_many", + parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)], + result={"name": "John", "age": 25}, + expected_exception=InvalidNumberOfParametersError, + expected_message="Invalid number of parameters", + ), + ErrorTestCase( + name="invalid_string_value_none", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ], + result={"name": None}, # Parameter present but None value, will trigger type check first + expected_exception=InvalidValueTypeError, + expected_message="Invalid value for parameter name, expected segment type: string, actual_type: none", + ), + ErrorTestCase( + name="invalid_select_value", + parameters=[ + ParameterConfig( + name="status", + type="select", # type: ignore + description="Status", + required=True, + options=["active", "inactive"], + ) + ], + result={"status": "pending"}, + expected_exception=InvalidSelectValueError, + expected_message="Invalid `select` value for parameter status", + ), + ErrorTestCase( + name="invalid_number_value_string", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + result={"age": "twenty-five"}, + expected_exception=InvalidValueTypeError, + expected_message="Invalid value for parameter age, expected segment type: number, actual_type: string", + ), + ErrorTestCase( + name="invalid_bool_value_string", + parameters=[ + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True) + ], + result={"active": "yes"}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter active, expected segment type: boolean, actual_type: string" + ), + ), + ErrorTestCase( + name="invalid_string_value_number", + parameters=[ + ParameterConfig( + name="description", type=SegmentType.STRING, description="Description", required=True + ) + ], + result={"description": 123}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter description, expected segment type: string, actual_type: integer" + ), + ), + ErrorTestCase( + name="invalid_array_value_not_list", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + result={"tags": "tag1,tag2,tag3"}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter tags, expected segment type: array[string], actual_type: string" + ), + ), + ErrorTestCase( + name="invalid_array_number_wrong_element_type", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + result={"scores": [85, "ninety-two", 78]}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter scores, expected segment type: array[number], actual_type: array[any]" + ), + ), + ErrorTestCase( + name="invalid_array_string_wrong_element_type", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + result={"tags": ["tag1", 123, "tag3"]}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter tags, expected segment type: array[string], actual_type: array[any]" + ), + ), + ErrorTestCase( + name="invalid_array_object_wrong_element_type", + parameters=[ + ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True) + ], + result={"items": [{"name": "item1"}, "item2"]}, + expected_exception=InvalidValueTypeError, + expected_message=( + "Invalid value for parameter items, expected segment type: array[object], actual_type: array[any]" + ), + ), + ErrorTestCase( + name="required_parameter_missing", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=False), + ], + result={"age": 25, "other": "value"}, # Missing required 'name' parameter, but has correct count + expected_exception=RequiredParameterMissingError, + expected_message="Parameter name is required", + ), + ] + + @pytest.mark.parametrize("test_case", get_valid_test_cases(), ids=ValidTestCase.get_name) + def test_validate_result_valid_cases(self, test_case): + """Test _validate_result with valid inputs.""" + helper = TestParameterExtractorNodeMethods() + + node_data = ParameterExtractorNodeData( + title="Test Node", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + query=["test_query"], + parameters=test_case.parameters, + reasoning_mode="function_call", + vision=VisionConfig(), + ) + + result = helper.validate_result(data=node_data, result=test_case.result) + assert result == test_case.result, f"Failed for case: {test_case.name}" + + @pytest.mark.parametrize("test_case", get_error_test_cases(), ids=ErrorTestCase.get_name) + def test_validate_result_error_cases(self, test_case): + """Test _validate_result with invalid inputs that should raise exceptions.""" + helper = TestParameterExtractorNodeMethods() + + node_data = ParameterExtractorNodeData( + title="Test Node", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + query=["test_query"], + parameters=test_case.parameters, + reasoning_mode="function_call", + vision=VisionConfig(), + ) + + with pytest.raises(test_case.expected_exception) as exc_info: + helper.validate_result(data=node_data, result=test_case.result) + + assert test_case.expected_message in str(exc_info.value), f"Failed for case: {test_case.name}" + + +class TestTransformResult: + """Test cases for _transform_result method.""" + + @staticmethod + def get_transform_test_cases() -> list[TransformTestCase]: + """Get test cases for result transformation.""" + return [ + # String parameter transformation + TransformTestCase( + name="string_parameter_present", + parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)], + input_result={"name": "John"}, + expected_result={"name": "John"}, + ), + TransformTestCase( + name="string_parameter_missing", + parameters=[ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True)], + input_result={}, + expected_result={"name": ""}, + ), + # Number parameter transformation + TransformTestCase( + name="number_parameter_int_present", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + input_result={"age": 25}, + expected_result={"age": 25}, + ), + TransformTestCase( + name="number_parameter_float_present", + parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)], + input_result={"price": 19.99}, + expected_result={"price": 19.99}, + ), + TransformTestCase( + name="number_parameter_missing", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + input_result={}, + expected_result={"age": 0}, + ), + # Bool parameter transformation + TransformTestCase( + name="bool_parameter_missing", + parameters=[ + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True) + ], + input_result={}, + expected_result={"active": False}, + ), + # Select parameter transformation + TransformTestCase( + name="select_parameter_present", + parameters=[ + ParameterConfig( + name="status", + type="select", # type: ignore + description="Status", + required=True, + options=["active", "inactive"], + ) + ], + input_result={"status": "active"}, + expected_result={"status": "active"}, + ), + TransformTestCase( + name="select_parameter_missing", + parameters=[ + ParameterConfig( + name="status", + type="select", # type: ignore + description="Status", + required=True, + options=["active", "inactive"], + ) + ], + input_result={}, + expected_result={"status": ""}, + ), + # Array parameter transformation - present cases + TransformTestCase( + name="array_string_parameter_present", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + input_result={"tags": ["tag1", "tag2"]}, + expected_result={ + "tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=["tag1", "tag2"]) + }, + ), + TransformTestCase( + name="array_number_parameter_present", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + input_result={"scores": [85, 92.5]}, + expected_result={ + "scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5]) + }, + ), + TransformTestCase( + name="array_number_parameter_with_string_conversion", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + input_result={"scores": [85, "92.5", "78"]}, + expected_result={ + "scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[85, 92.5, 78]) + }, + ), + TransformTestCase( + name="array_object_parameter_present", + parameters=[ + ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True) + ], + input_result={"items": [{"name": "item1"}, {"name": "item2"}]}, + expected_result={ + "items": build_segment_with_type( + segment_type=SegmentType.ARRAY_OBJECT, value=[{"name": "item1"}, {"name": "item2"}] + ) + }, + ), + # Array parameter transformation - missing cases + TransformTestCase( + name="array_string_parameter_missing", + parameters=[ + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True) + ], + input_result={}, + expected_result={"tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[])}, + ), + TransformTestCase( + name="array_number_parameter_missing", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + input_result={}, + expected_result={"scores": build_segment_with_type(segment_type=SegmentType.ARRAY_NUMBER, value=[])}, + ), + TransformTestCase( + name="array_object_parameter_missing", + parameters=[ + ParameterConfig(name="items", type=SegmentType.ARRAY_OBJECT, description="Items", required=True) + ], + input_result={}, + expected_result={"items": build_segment_with_type(segment_type=SegmentType.ARRAY_OBJECT, value=[])}, + ), + # Multiple parameters transformation + TransformTestCase( + name="multiple_parameters_mixed", + parameters=[ + ParameterConfig(name="name", type=SegmentType.STRING, description="Name", required=True), + ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True), + ParameterConfig(name="active", type=SegmentType.BOOLEAN, description="Active", required=True), + ParameterConfig(name="tags", type=SegmentType.ARRAY_STRING, description="Tags", required=True), + ], + input_result={"name": "John", "age": 25}, + expected_result={ + "name": "John", + "age": 25, + "active": False, + "tags": build_segment_with_type(segment_type=SegmentType.ARRAY_STRING, value=[]), + }, + ), + # Number parameter transformation with string conversion + TransformTestCase( + name="number_parameter_string_to_float", + parameters=[ParameterConfig(name="price", type=SegmentType.NUMBER, description="Price", required=True)], + input_result={"price": "19.99"}, + expected_result={"price": 19.99}, # String not converted, falls back to default + ), + TransformTestCase( + name="number_parameter_string_to_int", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + input_result={"age": "25"}, + expected_result={"age": 25}, # String not converted, falls back to default + ), + TransformTestCase( + name="number_parameter_invalid_string", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + input_result={"age": "invalid_number"}, + expected_result={"age": 0}, # Invalid string conversion fails, falls back to default + ), + TransformTestCase( + name="number_parameter_non_string_non_number", + parameters=[ParameterConfig(name="age", type=SegmentType.NUMBER, description="Age", required=True)], + input_result={"age": ["not_a_number"]}, # Non-string, non-number value + expected_result={"age": 0}, # Falls back to default + ), + TransformTestCase( + name="array_number_parameter_with_invalid_string_conversion", + parameters=[ + ParameterConfig(name="scores", type=SegmentType.ARRAY_NUMBER, description="Scores", required=True) + ], + input_result={"scores": [85, "invalid", "78"]}, + expected_result={ + "scores": build_segment_with_type( + segment_type=SegmentType.ARRAY_NUMBER, value=[85, 78] + ) # Invalid string skipped + }, + ), + ] + + @pytest.mark.parametrize("test_case", get_transform_test_cases(), ids=TransformTestCase.get_name) + def test_transform_result_cases(self, test_case): + """Test _transform_result with various inputs.""" + helper = TestParameterExtractorNodeMethods() + + node_data = ParameterExtractorNodeData( + title="Test Node", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + query=["test_query"], + parameters=test_case.parameters, + reasoning_mode="function_call", + vision=VisionConfig(), + ) + + result = helper.transform_result(data=node_data, result=test_case.input_result) + assert result == test_case.expected_result, ( + f"Failed for case: {test_case.name}. Expected: {test_case.expected_result}, Got: {result}" + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 8383aee0e..36a6fbb53 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -2,6 +2,8 @@ import time import uuid from unittest.mock import MagicMock, Mock +import pytest + from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment @@ -272,3 +274,220 @@ def test_array_file_contains_file_name(): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["result"] is True + + +def _get_test_conditions() -> list: + conditions = [ + # Test boolean "is" operator + {"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"}, + # Test boolean "is not" operator + {"comparison_operator": "is not", "variable_selector": ["start", "bool_false"], "value": "true"}, + # Test boolean "=" operator + {"comparison_operator": "=", "variable_selector": ["start", "bool_true"], "value": "1"}, + # Test boolean "≠" operator + {"comparison_operator": "≠", "variable_selector": ["start", "bool_false"], "value": "1"}, + # Test boolean "not null" operator + {"comparison_operator": "not null", "variable_selector": ["start", "bool_true"]}, + # Test boolean array "contains" operator + {"comparison_operator": "contains", "variable_selector": ["start", "bool_array"], "value": "true"}, + # Test boolean "in" operator + { + "comparison_operator": "in", + "variable_selector": ["start", "bool_true"], + "value": ["true", "false"], + }, + ] + return [Condition.model_validate(i) for i in conditions] + + +def _get_condition_test_id(c: Condition): + return c.comparison_operator + + +@pytest.mark.parametrize("condition", _get_test_conditions(), ids=_get_condition_test_id) +def test_execute_if_else_boolean_conditions(condition: Condition): + """Test IfElseNode with boolean conditions using various operators""" + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool with boolean values + pool = VariablePool( + system_variables=SystemVariable(files=[], user_id="aaa"), + ) + pool.add(["start", "bool_true"], True) + pool.add(["start", "bool_false"], False) + pool.add(["start", "bool_array"], [True, False, True]) + pool.add(["start", "mixed_array"], [True, "false", 1, 0]) + + node_data = { + "title": "Boolean Test", + "type": "if-else", + "logical_operator": "and", + "conditions": [condition.model_dump()], + } + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={"id": "if-else", "data": node_data}, + ) + node.init_node_data(node_data) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is True + + +def test_execute_if_else_boolean_false_conditions(): + """Test IfElseNode with boolean conditions that should evaluate to false""" + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool with boolean values + pool = VariablePool( + system_variables=SystemVariable(files=[], user_id="aaa"), + ) + pool.add(["start", "bool_true"], True) + pool.add(["start", "bool_false"], False) + pool.add(["start", "bool_array"], [True, False, True]) + + node_data = { + "title": "Boolean False Test", + "type": "if-else", + "logical_operator": "or", + "conditions": [ + # Test boolean "is" operator (should be false) + {"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "false"}, + # Test boolean "=" operator (should be false) + {"comparison_operator": "=", "variable_selector": ["start", "bool_false"], "value": "1"}, + # Test boolean "not contains" operator (should be false) + { + "comparison_operator": "not contains", + "variable_selector": ["start", "bool_array"], + "value": "true", + }, + ], + } + + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "id": "if-else", + "data": node_data, + }, + ) + node.init_node_data(node_data) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is False + + +def test_execute_if_else_boolean_cases_structure(): + """Test IfElseNode with boolean conditions using the new cases structure""" + graph_config = {"edges": [], "nodes": [{"data": {"type": "start"}, "id": "start"}]} + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool with boolean values + pool = VariablePool( + system_variables=SystemVariable(files=[], user_id="aaa"), + ) + pool.add(["start", "bool_true"], True) + pool.add(["start", "bool_false"], False) + + node_data = { + "title": "Boolean Cases Test", + "type": "if-else", + "cases": [ + { + "case_id": "true", + "logical_operator": "and", + "conditions": [ + { + "comparison_operator": "is", + "variable_selector": ["start", "bool_true"], + "value": "true", + }, + { + "comparison_operator": "is not", + "variable_selector": ["start", "bool_false"], + "value": "true", + }, + ], + } + ], + } + node = IfElseNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={"id": "if-else", "data": node_data}, + ) + node.init_node_data(node_data) + + # Mock db.session.close() + db.session.close = MagicMock() + + # execute node + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] is True + assert result.outputs["selected_case_id"] == "true" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 5fc9eab2d..d4d6aa038 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -11,7 +11,8 @@ from core.workflow.nodes.list_operator.entities import ( FilterCondition, Limit, ListOperatorNodeData, - OrderBy, + Order, + OrderByConfig, ) from core.workflow.nodes.list_operator.exc import InvalidKeyError from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func @@ -27,7 +28,7 @@ def list_operator_node(): FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT]) ], ), - "order_by": OrderBy(enabled=False, value="asc"), + "order_by": OrderByConfig(enabled=False, value=Order.ASC), "limit": Limit(enabled=False, size=0), "extract_by": ExtractConfig(enabled=False, serial="1"), "title": "Test Title", diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 4f2542a32..2a193ef2d 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -24,16 +24,18 @@ from core.variables.segments import ( ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, + Segment, StringSegment, ) from core.variables.types import SegmentType from factories import variable_factory -from factories.variable_factory import TypeMismatchError, build_segment_with_type +from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type def test_string_variable(): @@ -139,6 +141,26 @@ def test_array_number_variable(): assert isinstance(variable.value[1], float) +def test_build_segment_scalar_values(): + @dataclass + class TestCase: + value: Any + expected: Segment + description: str + + cases = [ + TestCase( + value=True, + expected=BooleanSegment(value=True), + description="build_segment with boolean should yield BooleanSegment", + ) + ] + + for idx, c in enumerate(cases, 1): + seg = build_segment(c.value) + assert seg == c.expected, f"Test case {idx} failed: {c.description}" + + def test_array_object_variable(): mapping = { "id": str(uuid4()), @@ -847,15 +869,22 @@ class TestBuildSegmentValueErrors: f"but got: {error_message}" ) - def test_build_segment_boolean_type_note(self): - """Note: Boolean values are actually handled as integers in Python, so they don't raise ValueError.""" - # Boolean values in Python are subclasses of int, so they get processed as integers - # True becomes IntegerSegment(value=1) and False becomes IntegerSegment(value=0) + def test_build_segment_boolean_type(self): + """Test that Boolean values are correctly handled as boolean type, not integers.""" + # Boolean values should now be processed as BooleanSegment, not IntegerSegment + # This is because the bool check now comes before the int check in build_segment true_segment = variable_factory.build_segment(True) false_segment = variable_factory.build_segment(False) - # Verify they are processed as integers, not as errors - assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" - assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" - assert true_segment.value_type == SegmentType.INTEGER - assert false_segment.value_type == SegmentType.INTEGER + # Verify they are processed as booleans, not integers + assert true_segment.value is True, "Test case 1 (boolean_true): Expected True to be processed as boolean True" + assert false_segment.value is False, ( + "Test case 2 (boolean_false): Expected False to be processed as boolean False" + ) + assert true_segment.value_type == SegmentType.BOOLEAN + assert false_segment.value_type == SegmentType.BOOLEAN + + # Test array of booleans + bool_array_segment = variable_factory.build_segment([True, False, True]) + assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN + assert bool_array_segment.value == [True, False, True] diff --git a/simple_boolean_test.py b/simple_boolean_test.py new file mode 100644 index 000000000..832efd425 --- /dev/null +++ b/simple_boolean_test.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +""" +Simple test to verify boolean classes can be imported correctly. +""" + +import sys +import os + +# Add the api directory to the Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api")) + +try: + # Test that we can import the boolean classes + from core.variables.segments import BooleanSegment, ArrayBooleanSegment + from core.variables.variables import BooleanVariable, ArrayBooleanVariable + from core.variables.types import SegmentType + + print("✅ Successfully imported BooleanSegment") + print("✅ Successfully imported ArrayBooleanSegment") + print("✅ Successfully imported BooleanVariable") + print("✅ Successfully imported ArrayBooleanVariable") + print("✅ Successfully imported SegmentType") + + # Test that the segment types exist + print(f"✅ SegmentType.BOOLEAN = {SegmentType.BOOLEAN}") + print(f"✅ SegmentType.ARRAY_BOOLEAN = {SegmentType.ARRAY_BOOLEAN}") + + # Test creating boolean segments directly + bool_seg = BooleanSegment(value=True) + print(f"✅ Created BooleanSegment: {bool_seg}") + print(f" Value type: {bool_seg.value_type}") + print(f" Value: {bool_seg.value}") + + array_bool_seg = ArrayBooleanSegment(value=[True, False, True]) + print(f"✅ Created ArrayBooleanSegment: {array_bool_seg}") + print(f" Value type: {array_bool_seg.value_type}") + print(f" Value: {array_bool_seg.value}") + + print("\n🎉 All boolean class imports and basic functionality work correctly!") + +except ImportError as e: + print(f"❌ Import error: {e}") +except Exception as e: + print(f"❌ Error: {e}") + import traceback + + traceback.print_exc() diff --git a/test_boolean_conditions.py b/test_boolean_conditions.py new file mode 100644 index 000000000..776fe5509 --- /dev/null +++ b/test_boolean_conditions.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify boolean condition support in IfElseNode +""" + +import sys +import os + +# Add the api directory to the Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api")) + +from core.workflow.utils.condition.processor import ( + ConditionProcessor, + _evaluate_condition, +) + + +def test_boolean_conditions(): + """Test boolean condition evaluation""" + print("Testing boolean condition support...") + + # Test boolean "is" operator + result = _evaluate_condition(value=True, operator="is", expected="true") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'is' with True value passed") + + result = _evaluate_condition(value=False, operator="is", expected="false") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'is' with False value passed") + + # Test boolean "is not" operator + result = _evaluate_condition(value=True, operator="is not", expected="false") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'is not' with True value passed") + + result = _evaluate_condition(value=False, operator="is not", expected="true") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'is not' with False value passed") + + # Test boolean "=" operator + result = _evaluate_condition(value=True, operator="=", expected="1") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean '=' with True=1 passed") + + result = _evaluate_condition(value=False, operator="=", expected="0") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean '=' with False=0 passed") + + # Test boolean "≠" operator + result = _evaluate_condition(value=True, operator="≠", expected="0") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean '≠' with True≠0 passed") + + result = _evaluate_condition(value=False, operator="≠", expected="1") + assert result == True, f"Expected True, got {result}" + print("✓ Boolean '≠' with False≠1 passed") + + # Test boolean "in" operator + result = _evaluate_condition(value=True, operator="in", expected=["true", "false"]) + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'in' with True in array passed") + + result = _evaluate_condition(value=False, operator="in", expected=["true", "false"]) + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'in' with False in array passed") + + # Test boolean "not in" operator + result = _evaluate_condition(value=True, operator="not in", expected=["false", "0"]) + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'not in' with True not in [false, 0] passed") + + # Test boolean "null" and "not null" operators + result = _evaluate_condition(value=True, operator="not null", expected=None) + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'not null' with True passed") + + result = _evaluate_condition(value=False, operator="not null", expected=None) + assert result == True, f"Expected True, got {result}" + print("✓ Boolean 'not null' with False passed") + + print("\n🎉 All boolean condition tests passed!") + + +def test_backward_compatibility(): + """Test that existing string and number conditions still work""" + print("\nTesting backward compatibility...") + + # Test string conditions + result = _evaluate_condition(value="hello", operator="is", expected="hello") + assert result == True, f"Expected True, got {result}" + print("✓ String 'is' condition still works") + + result = _evaluate_condition(value="hello", operator="contains", expected="ell") + assert result == True, f"Expected True, got {result}" + print("✓ String 'contains' condition still works") + + # Test number conditions + result = _evaluate_condition(value=42, operator="=", expected="42") + assert result == True, f"Expected True, got {result}" + print("✓ Number '=' condition still works") + + result = _evaluate_condition(value=42, operator=">", expected="40") + assert result == True, f"Expected True, got {result}" + print("✓ Number '>' condition still works") + + print("✓ Backward compatibility maintained!") + + +if __name__ == "__main__": + try: + test_boolean_conditions() + test_backward_compatibility() + print( + "\n✅ All tests passed! Boolean support has been successfully added to IfElseNode." + ) + except Exception as e: + print(f"\n❌ Test failed: {e}") + sys.exit(1) diff --git a/test_boolean_contains_fix.py b/test_boolean_contains_fix.py new file mode 100644 index 000000000..88276e555 --- /dev/null +++ b/test_boolean_contains_fix.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" +Test script to verify the boolean array comparison fix in condition processor. +""" + +import sys +import os + +# Add the api directory to the Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api")) + +from core.workflow.utils.condition.processor import ( + _assert_contains, + _assert_not_contains, +) + + +def test_boolean_array_contains(): + """Test that boolean arrays work correctly with string comparisons.""" + + # Test case 1: Boolean array [True, False, True] contains "true" + bool_array = [True, False, True] + + # Should return True because "true" converts to True and True is in the array + result1 = _assert_contains(value=bool_array, expected="true") + print(f"Test 1 - [True, False, True] contains 'true': {result1}") + assert result1 == True, "Expected True but got False" + + # Should return True because "false" converts to False and False is in the array + result2 = _assert_contains(value=bool_array, expected="false") + print(f"Test 2 - [True, False, True] contains 'false': {result2}") + assert result2 == True, "Expected True but got False" + + # Test case 2: Boolean array [True, True] does not contain "false" + bool_array2 = [True, True] + result3 = _assert_contains(value=bool_array2, expected="false") + print(f"Test 3 - [True, True] contains 'false': {result3}") + assert result3 == False, "Expected False but got True" + + # Test case 3: Test not_contains + result4 = _assert_not_contains(value=bool_array2, expected="false") + print(f"Test 4 - [True, True] not contains 'false': {result4}") + assert result4 == True, "Expected True but got False" + + result5 = _assert_not_contains(value=bool_array, expected="true") + print(f"Test 5 - [True, False, True] not contains 'true': {result5}") + assert result5 == False, "Expected False but got True" + + # Test case 4: Test with different string representations + result6 = _assert_contains( + value=bool_array, expected="1" + ) # "1" should convert to True + print(f"Test 6 - [True, False, True] contains '1': {result6}") + assert result6 == True, "Expected True but got False" + + result7 = _assert_contains( + value=bool_array, expected="0" + ) # "0" should convert to False + print(f"Test 7 - [True, False, True] contains '0': {result7}") + assert result7 == True, "Expected True but got False" + + print("\n✅ All boolean array comparison tests passed!") + + +if __name__ == "__main__": + test_boolean_array_contains() diff --git a/test_boolean_factory.py b/test_boolean_factory.py new file mode 100644 index 000000000..00e250b6d --- /dev/null +++ b/test_boolean_factory.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify boolean type inference in variable factory. +""" + +import sys +import os + +# Add the api directory to the Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api")) + +try: + from factories.variable_factory import build_segment, segment_to_variable + from core.variables.segments import BooleanSegment, ArrayBooleanSegment + from core.variables.variables import BooleanVariable, ArrayBooleanVariable + from core.variables.types import SegmentType + + def test_boolean_inference(): + print("Testing boolean type inference...") + + # Test single boolean values + true_segment = build_segment(True) + false_segment = build_segment(False) + + print(f"True value: {true_segment}") + print(f"Type: {type(true_segment)}") + print(f"Value type: {true_segment.value_type}") + print(f"Is BooleanSegment: {isinstance(true_segment, BooleanSegment)}") + + print(f"\nFalse value: {false_segment}") + print(f"Type: {type(false_segment)}") + print(f"Value type: {false_segment.value_type}") + print(f"Is BooleanSegment: {isinstance(false_segment, BooleanSegment)}") + + # Test array of booleans + bool_array_segment = build_segment([True, False, True]) + print(f"\nBoolean array: {bool_array_segment}") + print(f"Type: {type(bool_array_segment)}") + print(f"Value type: {bool_array_segment.value_type}") + print( + f"Is ArrayBooleanSegment: {isinstance(bool_array_segment, ArrayBooleanSegment)}" + ) + + # Test empty boolean array + empty_bool_array = build_segment([]) + print(f"\nEmpty array: {empty_bool_array}") + print(f"Type: {type(empty_bool_array)}") + print(f"Value type: {empty_bool_array.value_type}") + + # Test segment to variable conversion + bool_var = segment_to_variable( + segment=true_segment, selector=["test", "bool_var"], name="test_boolean" + ) + print(f"\nBoolean variable: {bool_var}") + print(f"Type: {type(bool_var)}") + print(f"Is BooleanVariable: {isinstance(bool_var, BooleanVariable)}") + + array_bool_var = segment_to_variable( + segment=bool_array_segment, + selector=["test", "array_bool_var"], + name="test_array_boolean", + ) + print(f"\nArray boolean variable: {array_bool_var}") + print(f"Type: {type(array_bool_var)}") + print( + f"Is ArrayBooleanVariable: {isinstance(array_bool_var, ArrayBooleanVariable)}" + ) + + # Test that bool comes before int (critical ordering) + print(f"\nTesting bool vs int precedence:") + print(f"True is instance of bool: {isinstance(True, bool)}") + print(f"True is instance of int: {isinstance(True, int)}") + print(f"False is instance of bool: {isinstance(False, bool)}") + print(f"False is instance of int: {isinstance(False, int)}") + + # Verify that boolean values are correctly inferred as boolean, not int + assert true_segment.value_type == SegmentType.BOOLEAN, ( + "True should be inferred as BOOLEAN" + ) + assert false_segment.value_type == SegmentType.BOOLEAN, ( + "False should be inferred as BOOLEAN" + ) + assert bool_array_segment.value_type == SegmentType.ARRAY_BOOLEAN, ( + "Boolean array should be inferred as ARRAY_BOOLEAN" + ) + + print("\n✅ All boolean inference tests passed!") + + if __name__ == "__main__": + test_boolean_inference() + +except ImportError as e: + print(f"Import error: {e}") + print("Make sure you're running this from the correct directory") +except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() diff --git a/test_boolean_variable_assigner.py b/test_boolean_variable_assigner.py new file mode 100644 index 000000000..388266760 --- /dev/null +++ b/test_boolean_variable_assigner.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +Test script to verify boolean support in VariableAssigner node +""" + +import sys +import os + +# Add the api directory to the Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "api")) + +from core.variables import SegmentType +from core.workflow.nodes.variable_assigner.v2.helpers import ( + is_operation_supported, + is_constant_input_supported, + is_input_value_valid, +) +from core.workflow.nodes.variable_assigner.v2.enums import Operation +from core.workflow.nodes.variable_assigner.v2.constants import EMPTY_VALUE_MAPPING + + +def test_boolean_operation_support(): + """Test that boolean types support the correct operations""" + print("Testing boolean operation support...") + + # Boolean should support SET, OVER_WRITE, and CLEAR + assert is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET + ) + assert is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE + ) + assert is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.CLEAR + ) + + # Boolean should NOT support arithmetic operations + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.ADD + ) + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.SUBTRACT + ) + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.MULTIPLY + ) + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.DIVIDE + ) + + # Boolean should NOT support array operations + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.APPEND + ) + assert not is_operation_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.EXTEND + ) + + print("✓ Boolean operation support tests passed") + + +def test_array_boolean_operation_support(): + """Test that array boolean types support the correct operations""" + print("Testing array boolean operation support...") + + # Array boolean should support APPEND, EXTEND, SET, OVER_WRITE, CLEAR + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND + ) + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.EXTEND + ) + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.OVER_WRITE + ) + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.CLEAR + ) + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_FIRST + ) + assert is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.REMOVE_LAST + ) + + # Array boolean should NOT support arithmetic operations + assert not is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.ADD + ) + assert not is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.SUBTRACT + ) + assert not is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.MULTIPLY + ) + assert not is_operation_supported( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.DIVIDE + ) + + print("✓ Array boolean operation support tests passed") + + +def test_boolean_constant_input_support(): + """Test that boolean types support constant input for correct operations""" + print("Testing boolean constant input support...") + + # Boolean should support constant input for SET and OVER_WRITE + assert is_constant_input_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET + ) + assert is_constant_input_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE + ) + + # Boolean should NOT support constant input for arithmetic operations + assert not is_constant_input_supported( + variable_type=SegmentType.BOOLEAN, operation=Operation.ADD + ) + + print("✓ Boolean constant input support tests passed") + + +def test_boolean_input_validation(): + """Test that boolean input validation works correctly""" + print("Testing boolean input validation...") + + # Boolean values should be valid for boolean type + assert is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=True + ) + assert is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=False + ) + assert is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.OVER_WRITE, value=True + ) + + # Non-boolean values should be invalid for boolean type + assert not is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value="true" + ) + assert not is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=1 + ) + assert not is_input_value_valid( + variable_type=SegmentType.BOOLEAN, operation=Operation.SET, value=0 + ) + + print("✓ Boolean input validation tests passed") + + +def test_array_boolean_input_validation(): + """Test that array boolean input validation works correctly""" + print("Testing array boolean input validation...") + + # Boolean values should be valid for array boolean append + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=True + ) + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, operation=Operation.APPEND, value=False + ) + + # Boolean arrays should be valid for extend/overwrite + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, + operation=Operation.EXTEND, + value=[True, False, True], + ) + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, + operation=Operation.OVER_WRITE, + value=[False, False], + ) + + # Non-boolean values should be invalid + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, + operation=Operation.APPEND, + value="true", + ) + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_BOOLEAN, + operation=Operation.EXTEND, + value=[True, "false"], + ) + + print("✓ Array boolean input validation tests passed") + + +def test_empty_value_mapping(): + """Test that empty value mapping includes boolean types""" + print("Testing empty value mapping...") + + # Check that boolean types have correct empty values + assert SegmentType.BOOLEAN in EMPTY_VALUE_MAPPING + assert EMPTY_VALUE_MAPPING[SegmentType.BOOLEAN] is False + + assert SegmentType.ARRAY_BOOLEAN in EMPTY_VALUE_MAPPING + assert EMPTY_VALUE_MAPPING[SegmentType.ARRAY_BOOLEAN] == [] + + print("✓ Empty value mapping tests passed") + + +def main(): + """Run all tests""" + print("Running VariableAssigner boolean support tests...\n") + + try: + test_boolean_operation_support() + test_array_boolean_operation_support() + test_boolean_constant_input_support() + test_boolean_input_validation() + test_array_boolean_input_validation() + test_empty_value_mapping() + + print( + "\n🎉 All tests passed! Boolean support has been successfully added to VariableAssigner." + ) + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/web/app/components/app/configuration/config-var/config-modal/config.ts b/web/app/components/app/configuration/config-var/config-modal/config.ts new file mode 100644 index 000000000..0de8f7930 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-modal/config.ts @@ -0,0 +1,24 @@ +export const jsonObjectWrap = { + type: 'object', + properties: {}, + required: [], + additionalProperties: true, +} + +export const jsonConfigPlaceHolder = JSON.stringify( + { + foo: { + type: 'string', + }, + bar: { + type: 'object', + properties: { + sub: { + type: 'number', + }, + }, + required: [], + additionalProperties: true, + }, + }, null, 2, +) diff --git a/web/app/components/app/configuration/config-var/config-modal/field.tsx b/web/app/components/app/configuration/config-var/config-modal/field.tsx index 78bd2d9f7..b24e0be6c 100644 --- a/web/app/components/app/configuration/config-var/config-modal/field.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/field.tsx @@ -2,21 +2,28 @@ import type { FC } from 'react' import React from 'react' import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' type Props = { className?: string title: string + isOptional?: boolean children: React.JSX.Element } const Field: FC = ({ className, title, + isOptional, children, }) => { + const { t } = useTranslation() return (
-
{title}
+
+ {title} + {isOptional && ({t('appDebug.variableConfig.optional')})} +
{children}
) diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 4ba451452..cecc076fe 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -1,13 +1,12 @@ 'use client' import type { ChangeEvent, FC } from 'react' -import React, { useCallback, useEffect, useRef, useState } from 'react' +import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import produce from 'immer' import ModalFoot from '../modal-foot' import ConfigSelect from '../config-select' import ConfigString from '../config-string' -import SelectTypeItem from '../select-type-item' import Field from './field' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' @@ -20,7 +19,13 @@ import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/ import Checkbox from '@/app/components/base/checkbox' import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants' import { DEFAULT_VALUE_MAX_LEN } from '@/config' +import type { Item as SelectItem } from './type-select' +import TypeSelector from './type-select' import { SimpleSelect } from '@/app/components/base/select' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import { jsonConfigPlaceHolder, jsonObjectWrap } from './config' +import { useStore as useAppStore } from '@/app/components/app/store' import Textarea from '@/app/components/base/textarea' import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader' import { TransferMethod } from '@/types/app' @@ -51,6 +56,20 @@ const ConfigModal: FC = ({ const [tempPayload, setTempPayload] = useState(payload || getNewVarInWorkflow('') as any) const { type, label, variable, options, max_length } = tempPayload const modalRef = useRef(null) + const appDetail = useAppStore(state => state.appDetail) + const isBasicApp = appDetail?.mode !== 'advanced-chat' && appDetail?.mode !== 'workflow' + const isSupportJSON = false + const jsonSchemaStr = useMemo(() => { + const isJsonObject = type === InputVarType.jsonObject + if (!isJsonObject || !tempPayload.json_schema) + return '' + try { + return JSON.stringify(JSON.parse(tempPayload.json_schema).properties, null, 2) + } + catch (_e) { + return '' + } + }, [tempPayload.json_schema]) useEffect(() => { // To fix the first input element auto focus, then directly close modal will raise error if (isShow) @@ -82,25 +101,74 @@ const ConfigModal: FC = ({ } }, []) - const handleTypeChange = useCallback((type: InputVarType) => { - return () => { - const newPayload = produce(tempPayload, (draft) => { - draft.type = type - // Clear default value when switching types - draft.default = undefined - if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { - (Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => { - if (key !== 'max_length') - (draft as any)[key] = (DEFAULT_FILE_UPLOAD_SETTING as any)[key] - }) - if (type === InputVarType.multiFiles) - draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length - } - if (type === InputVarType.paragraph) - draft.max_length = DEFAULT_VALUE_MAX_LEN - }) - setTempPayload(newPayload) + const handleJSONSchemaChange = useCallback((value: string) => { + try { + const v = JSON.parse(value) + const res = { + ...jsonObjectWrap, + properties: v, + } + handlePayloadChange('json_schema')(JSON.stringify(res, null, 2)) } + catch (_e) { + return null + } + }, [handlePayloadChange]) + + const selectOptions: SelectItem[] = [ + { + name: t('appDebug.variableConfig.text-input'), + value: InputVarType.textInput, + }, + { + name: t('appDebug.variableConfig.paragraph'), + value: InputVarType.paragraph, + }, + { + name: t('appDebug.variableConfig.select'), + value: InputVarType.select, + }, + { + name: t('appDebug.variableConfig.number'), + value: InputVarType.number, + }, + { + name: t('appDebug.variableConfig.checkbox'), + value: InputVarType.checkbox, + }, + ...(supportFile ? [ + { + name: t('appDebug.variableConfig.single-file'), + value: InputVarType.singleFile, + }, + { + name: t('appDebug.variableConfig.multi-files'), + value: InputVarType.multiFiles, + }, + ] : []), + ...((!isBasicApp && isSupportJSON) ? [{ + name: t('appDebug.variableConfig.json'), + value: InputVarType.jsonObject, + }] : []), + ] + + const handleTypeChange = useCallback((item: SelectItem) => { + const type = item.value as InputVarType + + const newPayload = produce(tempPayload, (draft) => { + draft.type = type + if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { + (Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => { + if (key !== 'max_length') + (draft as any)[key] = (DEFAULT_FILE_UPLOAD_SETTING as any)[key] + }) + if (type === InputVarType.multiFiles) + draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length + } + if (type === InputVarType.paragraph) + draft.max_length = DEFAULT_VALUE_MAX_LEN + }) + setTempPayload(newPayload) }, [tempPayload]) const handleVarKeyBlur = useCallback((e: any) => { @@ -142,15 +210,6 @@ const ConfigModal: FC = ({ if (!isVariableNameValid) return - // TODO: check if key already exists. should the consider the edit case - // if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) { - // Toast.notify({ - // type: 'error', - // message: t('appDebug.varKeyError.keyAlreadyExists', { key: tempPayload.variable }), - // }) - // return - // } - if (!tempPayload.label) { Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.labelNameRequired') }) return @@ -204,18 +263,8 @@ const ConfigModal: FC = ({ >
- -
- - - - - {supportFile && <> - - - } -
+
@@ -330,6 +379,21 @@ const ConfigModal: FC = ({ )} + {type === InputVarType.jsonObject && ( + + {jsonConfigPlaceHolder}
+ } + /> + + )} +
handlePayloadChange('required')(!tempPayload.required)} /> {t('appDebug.variableConfig.required')} diff --git a/web/app/components/app/configuration/config-var/config-modal/type-select.tsx b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx new file mode 100644 index 000000000..3f6a01ed7 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx @@ -0,0 +1,97 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import { ChevronDownIcon } from '@heroicons/react/20/solid' +import classNames from '@/utils/classnames' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon' +import type { InputVarType } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' +import Badge from '@/app/components/base/badge' +import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils' + +export type Item = { + value: InputVarType + name: string +} + +type Props = { + value: string | number + onSelect: (value: Item) => void + items: Item[] + popupClassName?: string + popupInnerClassName?: string + readonly?: boolean + hideChecked?: boolean +} +const TypeSelector: FC = ({ + value, + onSelect, + items, + popupInnerClassName, + readonly, +}) => { + const [open, setOpen] = useState(false) + const selectedItem = value ? items.find(item => item.value === value) : undefined + + return ( + + !readonly && setOpen(v => !v)} className='w-full'> +
+
+ + + {selectedItem?.name} + +
+
+ {inputVarTypeToVarType(selectedItem?.value as InputVarType)} + +
+
+ +
+ +
+ {items.map((item: Item) => ( +
{ + onSelect(item) + setOpen(false) + }} + > +
+ + {item.name} +
+ {inputVarTypeToVarType(item.value)} +
+ ))} +
+
+
+ ) +} + +export default TypeSelector diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 612d47603..2ac68227e 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -12,7 +12,7 @@ import SelectVarType from './select-var-type' import Tooltip from '@/app/components/base/tooltip' import type { PromptVariable } from '@/models/debug' import { DEFAULT_VALUE_MAX_LEN } from '@/config' -import { getNewVar } from '@/utils/var' +import { getNewVar, hasDuplicateStr } from '@/utils/var' import Toast from '@/app/components/base/toast' import Confirm from '@/app/components/base/confirm' import ConfigContext from '@/context/debug-configuration' @@ -80,7 +80,28 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar delete draft[currIndex].options }) + const newList = newPromptVariables + let errorMsgKey = '' + let typeName = '' + if (hasDuplicateStr(newList.map(item => item.key))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.varName' + } + else if (hasDuplicateStr(newList.map(item => item.name as string))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.labelName' + } + + if (errorMsgKey) { + Toast.notify({ + type: 'error', + message: t(errorMsgKey, { key: t(typeName) }), + }) + return false + } + onPromptVariablesChange?.(newPromptVariables) + return true } const { setShowExternalDataToolModal } = useModalContext() @@ -190,7 +211,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar const handleConfig = ({ key, type, index, name, config, icon, icon_background }: ExternalDataToolParams) => { // setCurrKey(key) setCurrIndex(index) - if (type !== 'string' && type !== 'paragraph' && type !== 'select' && type !== 'number') { + if (type !== 'string' && type !== 'paragraph' && type !== 'select' && type !== 'number' && type !== 'checkbox') { handleOpenExternalDataToolModal({ key, type, index, name, config, icon, icon_background }, promptVariables) return } @@ -245,7 +266,8 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar isShow={isShowEditModal} onClose={hideEditModal} onConfirm={(item) => { - updatePromptVariableItem(item) + const isValid = updatePromptVariableItem(item) + if (!isValid) return hideEditModal() }} varKeys={promptVariables.map(v => v.key)} diff --git a/web/app/components/app/configuration/config-var/select-var-type.tsx b/web/app/components/app/configuration/config-var/select-var-type.tsx index ce5a5fccf..2977f05d9 100644 --- a/web/app/components/app/configuration/config-var/select-var-type.tsx +++ b/web/app/components/app/configuration/config-var/select-var-type.tsx @@ -65,6 +65,7 @@ const SelectVarType: FC = ({ +
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index dad5441a5..62bd57c5d 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -120,6 +120,8 @@ const SettingBuiltInTool: FC = ({ return t('tools.setBuiltInTools.number') if (type === 'text-input') return t('tools.setBuiltInTools.string') + if (type === 'checkbox') + return 'boolean' if (type === 'file') return t('tools.setBuiltInTools.file') return type diff --git a/web/app/components/app/configuration/debug/chat-user-input.tsx b/web/app/components/app/configuration/debug/chat-user-input.tsx index fb4ac31d9..ac07691ce 100644 --- a/web/app/components/app/configuration/debug/chat-user-input.tsx +++ b/web/app/components/app/configuration/debug/chat-user-input.tsx @@ -8,6 +8,7 @@ import Textarea from '@/app/components/base/textarea' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import type { Inputs } from '@/models/debug' import cn from '@/utils/classnames' +import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' type Props = { inputs: Inputs @@ -31,7 +32,7 @@ const ChatUserInput = ({ return obj })() - const handleInputValueChange = (key: string, value: string) => { + const handleInputValueChange = (key: string, value: string | boolean) => { if (!(key in promptVariableObj)) return @@ -55,10 +56,12 @@ const ChatUserInput = ({ className='mb-4 last-of-type:mb-0' >
+ {type !== 'checkbox' && (
{name || key}
{!required && {t('workflow.panel.optional')}}
+ )}
{type === 'string' && ( )} + {type === 'checkbox' && ( + { handleInputValueChange(key, value) }} + /> + )}
diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index 38b0c890e..9a50d1b87 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -34,7 +34,7 @@ import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows import TooltipPlus from '@/app/components/base/tooltip' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import type { ModelConfig as BackendModelConfig, VisionFile, VisionSettings } from '@/types/app' -import { promptVariablesToUserInputsForm } from '@/utils/model-config' +import { formatBooleanInputs, promptVariablesToUserInputsForm } from '@/utils/model-config' import TextGeneration from '@/app/components/app/text-generate/item' import { IS_CE_EDITION } from '@/config' import type { Inputs } from '@/models/debug' @@ -259,7 +259,7 @@ const Debug: FC = ({ } const data: Record = { - inputs, + inputs: formatBooleanInputs(modelConfig.configs.prompt_variables, inputs), model_config: postModelConfig, } diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 42affb055..512f57bcc 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -60,7 +60,6 @@ import { useModelListAndDefaultModelAndCurrentProviderAndModel, useTextGenerationCurrentProviderAndModelAndModelList, } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { fetchCollectionList } from '@/service/tools' import type { Collection } from '@/app/components/tools/types' import { useStore as useAppStore } from '@/app/components/app/store' import { @@ -82,6 +81,7 @@ import { supportFunctionCall } from '@/utils/tool-call' import { MittProvider } from '@/context/mitt-context' import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params' import Toast from '@/app/components/base/toast' +import { fetchCollectionList } from '@/service/tools' import { useAppContext } from '@/context/app-context' type PublishConfig = { diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index b36bf8848..e88268ba4 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -22,6 +22,7 @@ import type { VisionFile, VisionSettings } from '@/types/app' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import { useStore as useAppStore } from '@/app/components/app/store' import cn from '@/utils/classnames' +import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' export type IPromptValuePanelProps = { appType: AppType @@ -66,7 +67,7 @@ const PromptValuePanel: FC = ({ else { return !modelConfig.configs.prompt_template } }, [chatPromptConfig.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType]) - const handleInputValueChange = (key: string, value: string) => { + const handleInputValueChange = (key: string, value: string | boolean) => { if (!(key in promptVariableObj)) return @@ -109,10 +110,12 @@ const PromptValuePanel: FC = ({ className='mb-4 last-of-type:mb-0' >
-
-
{name || key}
- {!required && {t('workflow.panel.optional')}} -
+ {type !== 'checkbox' && ( +
+
{name || key}
+ {!required && {t('workflow.panel.optional')}} +
+ )}
{type === 'string' && ( = ({ maxLength={max_length || DEFAULT_VALUE_MAX_LEN} /> )} + {type === 'checkbox' && ( + { handleInputValueChange(key, value) }} + /> + )}
diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index f3768e80c..e856e6a88 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -23,6 +23,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested import { Markdown } from '@/app/components/base/markdown' import cn from '@/utils/classnames' import type { FileEntity } from '../../file-uploader/types' +import { formatBooleanInputs } from '@/utils/model-config' import Avatar from '../../avatar' const ChatWrapper = () => { @@ -89,7 +90,7 @@ const ChatWrapper = () => { let hasEmptyInput = '' let fileIsUploading = false - const requiredVars = inputsForms.filter(({ required }) => required) + const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) if (requiredVars.length) { requiredVars.forEach(({ variable, label, type }) => { if (hasEmptyInput) @@ -131,7 +132,7 @@ const ChatWrapper = () => { const data: any = { query: message, files, - inputs: currentConversationId ? currentConversationInputs : newConversationInputs, + inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs), conversation_id: currentConversationId, parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null, } diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 5a2919fe5..714e38b21 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -222,6 +222,14 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { type: 'number', } } + + if(item.checkbox) { + return { + ...item.checkbox, + default: false, + type: 'checkbox', + } + } if (item.select) { const isInputInOptions = item.select.options.includes(initInputs[item.select.variable]) return { @@ -245,6 +253,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } } + if (item.json_object) { + return { + ...item.json_object, + type: 'json_object', + } + } + let value = initInputs[item['text-input'].variable] if (value && item['text-input'].max_length && value.length > item['text-input'].max_length) value = value.slice(0, item['text-input'].max_length) @@ -340,7 +355,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { let hasEmptyInput = '' let fileIsUploading = false - const requiredVars = inputsForms.filter(({ required }) => required) + const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) if (requiredVars.length) { requiredVars.forEach(({ variable, label, type }) => { if (hasEmptyInput) diff --git a/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx b/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx index 3304d50a5..392bdf2b7 100644 --- a/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx +++ b/web/app/components/base/chat/chat-with-history/inputs-form/content.tsx @@ -6,6 +6,9 @@ import Textarea from '@/app/components/base/textarea' import { PortalSelect } from '@/app/components/base/select' import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader' import { InputVarType } from '@/app/components/workflow/types' +import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' type Props = { showTip?: boolean @@ -42,12 +45,14 @@ const InputsFormContent = ({ showTip }: Props) => {
{visibleInputsForms.map(form => (
-
-
{form.label}
- {!form.required && ( -
{t('appDebug.variableTable.optional')}
- )} -
+ {form.type !== InputVarType.checkbox && ( +
+
{form.label}
+ {!form.required && ( +
{t('appDebug.variableTable.optional')}
+ )} +
+ )} {form.type === InputVarType.textInput && ( { placeholder={form.label} /> )} + {form.type === InputVarType.checkbox && ( + handleFormChange(form.variable, value)} + /> + )} {form.type === InputVarType.select && ( { }} /> )} + {form.type === InputVarType.jsonObject && ( + handleFormChange(form.variable, v)} + noWrapper + className='bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1' + placeholder={ +
{form.json_schema}
+ } + /> + )}
))} {showTip && ( diff --git a/web/app/components/base/chat/chat/check-input-forms-hooks.ts b/web/app/components/base/chat/chat/check-input-forms-hooks.ts index 62c59a06f..469e21002 100644 --- a/web/app/components/base/chat/chat/check-input-forms-hooks.ts +++ b/web/app/components/base/chat/chat/check-input-forms-hooks.ts @@ -12,7 +12,7 @@ export const useCheckInputsForms = () => { const checkInputsForm = useCallback((inputs: Record, inputsForm: InputForm[]) => { let hasEmptyInput = '' let fileIsUploading = false - const requiredVars = inputsForm.filter(({ required }) => required) + const requiredVars = inputsForm.filter(({ required, type }) => required && type !== InputVarType.checkbox) // boolean can be not checked if (requiredVars?.length) { requiredVars.forEach(({ variable, label, type }) => { diff --git a/web/app/components/base/chat/chat/utils.ts b/web/app/components/base/chat/chat/utils.ts index 69bc68077..199ccff57 100644 --- a/web/app/components/base/chat/chat/utils.ts +++ b/web/app/components/base/chat/chat/utils.ts @@ -31,6 +31,12 @@ export const getProcessedInputs = (inputs: Record, inputsForm: Inpu inputsForm.forEach((item) => { const inputValue = inputs[item.variable] + // set boolean type default value + if(item.type === InputVarType.checkbox) { + processedInputs[item.variable] = !!inputValue + return + } + if (!inputValue) return diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index 8429c82e0..14a291e9f 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -90,7 +90,7 @@ const ChatWrapper = () => { let hasEmptyInput = '' let fileIsUploading = false - const requiredVars = inputsForms.filter(({ required }) => required) + const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) // boolean can be not checked if (requiredVars.length) { requiredVars.forEach(({ variable, label, type }) => { if (hasEmptyInput) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 3281f05f7..f0e63abc7 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -195,6 +195,13 @@ export const useEmbeddedChatbot = () => { type: 'number', } } + if (item.checkbox) { + return { + ...item.checkbox, + default: false, + type: 'checkbox', + } + } if (item.select) { const isInputInOptions = item.select.options.includes(initInputs[item.select.variable]) return { @@ -218,6 +225,13 @@ export const useEmbeddedChatbot = () => { } } + if (item.json_object) { + return { + ...item.json_object, + type: 'json_object', + } + } + let value = initInputs[item['text-input'].variable] if (value && item['text-input'].max_length && value.length > item['text-input'].max_length) value = value.slice(0, item['text-input'].max_length) @@ -312,7 +326,7 @@ export const useEmbeddedChatbot = () => { let hasEmptyInput = '' let fileIsUploading = false - const requiredVars = inputsForms.filter(({ required }) => required) + const requiredVars = inputsForms.filter(({ required, type }) => required && type !== InputVarType.checkbox) if (requiredVars.length) { requiredVars.forEach(({ variable, label, type }) => { if (hasEmptyInput) diff --git a/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx b/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx index 29fa5394e..1235899d1 100644 --- a/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx +++ b/web/app/components/base/chat/embedded-chatbot/inputs-form/content.tsx @@ -6,6 +6,9 @@ import Textarea from '@/app/components/base/textarea' import { PortalSelect } from '@/app/components/base/select' import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader' import { InputVarType } from '@/app/components/workflow/types' +import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' type Props = { showTip?: boolean @@ -42,12 +45,14 @@ const InputsFormContent = ({ showTip }: Props) => {
{visibleInputsForms.map(form => (
+ {form.type !== InputVarType.checkbox && (
{form.label}
{!form.required && (
{t('appDebug.variableTable.optional')}
)}
+ )} {form.type === InputVarType.textInput && ( { placeholder={form.label} /> )} + {form.type === InputVarType.checkbox && ( + handleFormChange(form.variable, value)} + /> + )} {form.type === InputVarType.select && ( { }} /> )} + {form.type === InputVarType.jsonObject && ( + handleFormChange(form.variable, v)} + noWrapper + className='bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1' + placeholder={ +
{form.json_schema}
+ } + /> + )}
))} {showTip && ( diff --git a/web/app/components/base/form/types.ts b/web/app/components/base/form/types.ts index 5c8e36126..f4437948f 100644 --- a/web/app/components/base/form/types.ts +++ b/web/app/components/base/form/types.ts @@ -24,7 +24,7 @@ export enum FormTypeEnum { secretInput = 'secret-input', select = 'select', radio = 'radio', - boolean = 'boolean', + checkbox = 'checkbox', files = 'files', file = 'file', modelSelector = 'model-selector', diff --git a/web/app/components/base/prompt-editor/plugins/current-block/current-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/current-block/current-block-replacement-block.tsx index 8ca56b0cf..aa5636036 100644 --- a/web/app/components/base/prompt-editor/plugins/current-block/current-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/current-block/current-block-replacement-block.tsx @@ -53,7 +53,6 @@ const CurrentBlockReplacementBlock = ({ return mergeRegister( editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createCurrentBlockNode)), ) - // eslint-disable-next-line react-hooks/exhaustive-deps }, []) return null diff --git a/web/app/components/base/prompt-editor/plugins/error-message-block/error-message-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/error-message-block/error-message-block-replacement-block.tsx index 80c89c732..cd8df107f 100644 --- a/web/app/components/base/prompt-editor/plugins/error-message-block/error-message-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/error-message-block/error-message-block-replacement-block.tsx @@ -52,7 +52,6 @@ const ErrorMessageBlockReplacementBlock = ({ return mergeRegister( editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createErrorMessageBlockNode)), ) - // eslint-disable-next-line react-hooks/exhaustive-deps }, []) return null diff --git a/web/app/components/base/prompt-editor/plugins/last-run-block/last-run-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/last-run-block/last-run-block-replacement-block.tsx index 9d2882801..2e5f92e2a 100644 --- a/web/app/components/base/prompt-editor/plugins/last-run-block/last-run-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/last-run-block/last-run-block-replacement-block.tsx @@ -52,7 +52,6 @@ const LastRunReplacementBlock = ({ return mergeRegister( editor.registerNodeTransform(CustomTextNode, textNode => decoratorTransform(textNode, getMatch, createLastRunBlockNode)), ) - // eslint-disable-next-line react-hooks/exhaustive-deps }, []) return null diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx index d3ac9d7d2..12cd74e10 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel.tsx @@ -77,6 +77,13 @@ const AppInputsPanel = ({ required: false, } } + if(item.checkbox) { + return { + ...item.checkbox, + type: 'checkbox', + required: false, + } + } if (item.select) { return { ...item.select, @@ -103,6 +110,13 @@ const AppInputsPanel = ({ } } + if (item.json_object) { + return { + ...item.json_object, + type: 'json_object', + } + } + return { ...item['text-input'], type: 'text-input', diff --git a/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx b/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx index 3427587fd..b286d57dc 100644 --- a/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx +++ b/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx @@ -63,6 +63,8 @@ const StrategyDetail: FC = ({ return t('tools.setBuiltInTools.number') if (type === 'text-input') return t('tools.setBuiltInTools.string') + if (type === 'checkbox') + return 'boolean' if (type === 'file') return t('tools.setBuiltInTools.file') if (type === 'array[tools]') diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 97a3a7739..fc5422589 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -21,6 +21,7 @@ import { TEXT_GENERATION_TIMEOUT_MS } from '@/config' import { getFilesInLogs, } from '@/app/components/base/file-uploader/utils' +import { formatBooleanInputs } from '@/utils/model-config' export type IResultProps = { isWorkflow: boolean @@ -124,7 +125,9 @@ const Result: FC = ({ } let hasEmptyInput = '' - const requiredVars = prompt_variables?.filter(({ key, name, required }) => { + const requiredVars = prompt_variables?.filter(({ key, name, required, type }) => { + if(type === 'boolean') + return false // boolean input is not required const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null) return res }) || [] // compatible with old version @@ -158,7 +161,7 @@ const Result: FC = ({ return const data: Record = { - inputs, + inputs: formatBooleanInputs(promptConfig?.prompt_variables, inputs), } if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) { data.files = completionFiles.map((item) => { diff --git a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index 7622daa86..bae7a1d16 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -18,6 +18,9 @@ import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uplo import { getProcessedFiles } from '@/app/components/base/file-uploader/utils' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import cn from '@/utils/classnames' +import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' export type IRunOnceProps = { siteInfo: SiteInfo @@ -93,7 +96,9 @@ const RunOnce: FC = ({ {(inputs === null || inputs === undefined || Object.keys(inputs).length === 0) || !isInitialized ? null : promptConfig.prompt_variables.map(item => (
- + {item.type !== 'boolean' && ( + + )}
{item.type === 'select' && ( = ({
= { }, checkValid(payload: ListFilterNodeType, t: any) { let errorMessages = '' - const { variable, var_type, filter_by } = payload + const { variable, var_type, filter_by, item_var_type } = payload if (!errorMessages && !variable?.length) errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.listFilter.inputVar') }) @@ -51,7 +51,7 @@ const nodeDefault: NodeDefault = { if (!errorMessages && !filter_by.conditions[0]?.comparison_operator) errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.listFilter.filterConditionComparisonOperator') }) - if (!errorMessages && !comparisonOperatorNotRequireValue(filter_by.conditions[0]?.comparison_operator) && !filter_by.conditions[0]?.value) + if (!errorMessages && !comparisonOperatorNotRequireValue(filter_by.conditions[0]?.comparison_operator) && (item_var_type === VarType.boolean ? !filter_by.conditions[0]?.value === undefined : !filter_by.conditions[0]?.value)) errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.listFilter.filterConditionComparisonValue') }) } diff --git a/web/app/components/workflow/nodes/list-operator/types.ts b/web/app/components/workflow/nodes/list-operator/types.ts index 770590329..44203cd0f 100644 --- a/web/app/components/workflow/nodes/list-operator/types.ts +++ b/web/app/components/workflow/nodes/list-operator/types.ts @@ -14,7 +14,7 @@ export type Limit = { export type Condition = { key: string comparison_operator: ComparisonOperator - value: string | number | string[] + value: string | number | boolean | string[] } export type ListFilterNodeType = CommonNodeType & { diff --git a/web/app/components/workflow/nodes/list-operator/use-config.ts b/web/app/components/workflow/nodes/list-operator/use-config.ts index 21e976172..d53a0a6c3 100644 --- a/web/app/components/workflow/nodes/list-operator/use-config.ts +++ b/web/app/components/workflow/nodes/list-operator/use-config.ts @@ -45,7 +45,7 @@ const useConfig = (id: string, payload: ListFilterNodeType) => { isChatMode, isConstant: false, }) - let itemVarType = varType + let itemVarType switch (varType) { case VarType.arrayNumber: itemVarType = VarType.number @@ -59,6 +59,11 @@ const useConfig = (id: string, payload: ListFilterNodeType) => { case VarType.arrayObject: itemVarType = VarType.object break + case VarType.arrayBoolean: + itemVarType = VarType.boolean + break + default: + itemVarType = varType } return { varType, itemVarType } }, [availableNodes, getCurrentVariableType, inputs.variable, isChatMode, isInIteration, iterationNode, loopNode]) @@ -84,7 +89,7 @@ const useConfig = (id: string, payload: ListFilterNodeType) => { draft.filter_by.conditions = [{ key: (isFileArray && !draft.filter_by.conditions[0]?.key) ? 'name' : '', comparison_operator: getOperators(itemVarType, isFileArray ? { key: 'name' } : undefined)[0], - value: '', + value: itemVarType === VarType.boolean ? false : '', }] if (isFileArray && draft.order_by.enabled && !draft.order_by.key) draft.order_by.key = 'name' @@ -94,7 +99,7 @@ const useConfig = (id: string, payload: ListFilterNodeType) => { const filterVar = useCallback((varPayload: Var) => { // Don't know the item struct of VarType.arrayObject, so not support it - return [VarType.arrayNumber, VarType.arrayString, VarType.arrayFile].includes(varPayload.type) + return [VarType.arrayNumber, VarType.arrayString, VarType.arrayBoolean, VarType.arrayFile].includes(varPayload.type) }, []) const handleFilterEnabledChange = useCallback((enabled: boolean) => { diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx index fecd1093d..b87dc6e24 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx @@ -11,7 +11,6 @@ import VisualEditor from './visual-editor' import SchemaEditor from './schema-editor' import { checkJsonSchemaDepth, - convertBooleanToString, getValidationErrorMessage, jsonToSchema, preValidateSchema, @@ -87,7 +86,6 @@ const JsonSchemaConfig: FC = ({ setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) return } - convertBooleanToString(schema) const validationErrors = validateSchemaAgainstDraft7(schema) if (validationErrors.length > 0) { setValidationError(getValidationErrorMessage(validationErrors)) @@ -168,7 +166,6 @@ const JsonSchemaConfig: FC = ({ setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) return } - convertBooleanToString(schema) const validationErrors = validateSchemaAgainstDraft7(schema) if (validationErrors.length > 0) { setValidationError(getValidationErrorMessage(validationErrors)) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/index.tsx index 4023a937f..ae72d494d 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/index.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/index.tsx @@ -39,21 +39,19 @@ type EditCardProps = { const TYPE_OPTIONS = [ { value: Type.string, text: 'string' }, { value: Type.number, text: 'number' }, - // { value: Type.boolean, text: 'boolean' }, + { value: Type.boolean, text: 'boolean' }, { value: Type.object, text: 'object' }, { value: ArrayType.string, text: 'array[string]' }, { value: ArrayType.number, text: 'array[number]' }, - // { value: ArrayType.boolean, text: 'array[boolean]' }, { value: ArrayType.object, text: 'array[object]' }, ] const MAXIMUM_DEPTH_TYPE_OPTIONS = [ { value: Type.string, text: 'string' }, { value: Type.number, text: 'number' }, - // { value: Type.boolean, text: 'boolean' }, + { value: Type.boolean, text: 'boolean' }, { value: ArrayType.string, text: 'array[string]' }, { value: ArrayType.number, text: 'array[number]' }, - // { value: ArrayType.boolean, text: 'array[boolean]' }, ] const EditCard: FC = ({ diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index fd943d1fa..045acf399 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -303,6 +303,7 @@ export const getValidationErrorMessage = (errors: ValidationError[]) => { return message } +// Previous Not support boolean type, so transform boolean to string when paste it into schema editor export const convertBooleanToString = (schema: any) => { if (schema.type === Type.boolean) schema.type = Type.string diff --git a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx index b3ce67beb..6e573093b 100644 --- a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx @@ -36,6 +36,7 @@ import cn from '@/utils/classnames' import { SimpleSelect as Select } from '@/app/components/base/select' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import ConditionVarSelector from './condition-var-selector' +import BoolValue from '@/app/components/workflow/panel/chat-variable-panel/components/bool-value' const optionNameI18NPrefix = 'workflow.nodes.ifElse.optionName' @@ -129,12 +130,12 @@ const ConditionItem = ({ const isArrayValue = fileAttr?.key === 'transfer_method' || fileAttr?.key === 'type' - const handleUpdateConditionValue = useCallback((value: string) => { - if (value === condition.value || (isArrayValue && value === condition.value?.[0])) + const handleUpdateConditionValue = useCallback((value: string | boolean) => { + if (value === condition.value || (isArrayValue && value === (condition.value as string[])?.[0])) return const newCondition = { ...condition, - value: isArrayValue ? [value] : value, + value: isArrayValue ? [value as string] : value, } doUpdateCondition(newCondition) }, [condition, doUpdateCondition, isArrayValue]) @@ -253,7 +254,7 @@ const ConditionItem = ({ />
{ - !comparisonOperatorNotRequireValue(condition.comparison_operator) && !isNotInput && condition.varType !== VarType.number && ( + !comparisonOperatorNotRequireValue(condition.comparison_operator) && !isNotInput && condition.varType !== VarType.number && condition.varType !== VarType.boolean && (
) } + {!comparisonOperatorNotRequireValue(condition.comparison_operator) && condition.varType === VarType.boolean + &&
+ +
+ } { !comparisonOperatorNotRequireValue(condition.comparison_operator) && !isNotInput && condition.varType === VarType.number && (
diff --git a/web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx b/web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx index 4a05e457b..e4cc13835 100644 --- a/web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx +++ b/web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx @@ -18,33 +18,16 @@ import { ValueType, VarType, } from '@/app/components/workflow/types' +import BoolValue from '@/app/components/workflow/panel/chat-variable-panel/components/bool-value' -const objectPlaceholder = `# example -# { -# "name": "ray", -# "age": 20 -# }` -const arrayStringPlaceholder = `# example -# [ -# "value1", -# "value2" -# ]` -const arrayNumberPlaceholder = `# example -# [ -# 100, -# 200 -# ]` -const arrayObjectPlaceholder = `# example -# [ -# { -# "name": "ray", -# "age": 20 -# }, -# { -# "name": "lily", -# "age": 18 -# } -# ]` +import { + arrayBoolPlaceholder, + arrayNumberPlaceholder, + arrayObjectPlaceholder, + arrayStringPlaceholder, + objectPlaceholder, +} from '@/app/components/workflow/panel/chat-variable-panel/utils' +import ArrayBoolList from '@/app/components/workflow/panel/chat-variable-panel/components/array-bool-list' type FormItemProps = { nodeId: string @@ -83,6 +66,8 @@ const FormItem = ({ return arrayNumberPlaceholder if (var_type === VarType.arrayObject) return arrayObjectPlaceholder + if (var_type === VarType.arrayBoolean) + return arrayBoolPlaceholder return objectPlaceholder }, [var_type]) @@ -120,6 +105,14 @@ const FormItem = ({ /> ) } + { + value_type === ValueType.constant && var_type === VarType.boolean && ( + + ) + } { value_type === ValueType.constant && (var_type === VarType.object || var_type === VarType.arrayString || var_type === VarType.arrayNumber || var_type === VarType.arrayObject) @@ -137,6 +130,15 @@ const FormItem = ({
) } + { + value_type === ValueType.constant && var_type === VarType.arrayBoolean && ( + + ) + }
) } diff --git a/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx b/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx index 42dc34b39..7084389be 100644 --- a/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx +++ b/web/app/components/workflow/nodes/loop/components/loop-variables/item.tsx @@ -12,6 +12,7 @@ import type { } from '@/app/components/workflow/nodes/loop/types' import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var' import Toast from '@/app/components/base/toast' +import { ValueType, VarType } from '@/app/components/workflow/types' type ItemProps = { item: LoopVariable @@ -42,12 +43,25 @@ const Item = ({ handleUpdateLoopVariable(item.id, { label: e.target.value }) }, [item.id, handleUpdateLoopVariable]) + const getDefaultValue = useCallback((varType: VarType, valueType: ValueType) => { + if(valueType === ValueType.variable) + return undefined + switch (varType) { + case VarType.boolean: + return false + case VarType.arrayBoolean: + return [false] + default: + return undefined + } + }, []) + const handleUpdateItemVarType = useCallback((value: any) => { - handleUpdateLoopVariable(item.id, { var_type: value, value: undefined }) + handleUpdateLoopVariable(item.id, { var_type: value, value: getDefaultValue(value, item.value_type) }) }, [item.id, handleUpdateLoopVariable]) const handleUpdateItemValueType = useCallback((value: any) => { - handleUpdateLoopVariable(item.id, { value_type: value, value: undefined }) + handleUpdateLoopVariable(item.id, { value_type: value, value: getDefaultValue(item.var_type, value) }) }, [item.id, handleUpdateLoopVariable]) const handleUpdateItemValue = useCallback((value: any) => { diff --git a/web/app/components/workflow/nodes/loop/components/loop-variables/variable-type-select.tsx b/web/app/components/workflow/nodes/loop/components/loop-variables/variable-type-select.tsx index 5271660fc..78a995d57 100644 --- a/web/app/components/workflow/nodes/loop/components/loop-variables/variable-type-select.tsx +++ b/web/app/components/workflow/nodes/loop/components/loop-variables/variable-type-select.tsx @@ -22,6 +22,10 @@ const VariableTypeSelect = ({ label: 'Object', value: VarType.object, }, + { + label: 'Boolean', + value: VarType.boolean, + }, { label: 'Array[string]', value: VarType.arrayString, @@ -34,6 +38,10 @@ const VariableTypeSelect = ({ label: 'Array[object]', value: VarType.arrayObject, }, + { + label: 'Array[boolean]', + value: VarType.arrayBoolean, + }, ] return ( diff --git a/web/app/components/workflow/nodes/loop/default.ts b/web/app/components/workflow/nodes/loop/default.ts index b44643245..66ff20b37 100644 --- a/web/app/components/workflow/nodes/loop/default.ts +++ b/web/app/components/workflow/nodes/loop/default.ts @@ -1,4 +1,4 @@ -import { BlockEnum } from '../../types' +import { BlockEnum, VarType } from '../../types' import type { NodeDefault } from '../../types' import { ComparisonOperator, LogicalOperator, type LoopNodeType } from './types' import { isEmptyRelatedOperator } from './utils' @@ -55,7 +55,7 @@ const nodeDefault: NodeDefault = { errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t(`${i18nPrefix}.fields.variableValue`) }) } else { - if (!isEmptyRelatedOperator(condition.comparison_operator!) && !condition.value) + if (!isEmptyRelatedOperator(condition.comparison_operator!) && (condition.varType === VarType.boolean ? condition.value === undefined : !condition.value)) errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t(`${i18nPrefix}.fields.variableValue`) }) } } diff --git a/web/app/components/workflow/nodes/loop/types.ts b/web/app/components/workflow/nodes/loop/types.ts index 80c7d51cc..fe23b1f8c 100644 --- a/web/app/components/workflow/nodes/loop/types.ts +++ b/web/app/components/workflow/nodes/loop/types.ts @@ -44,7 +44,7 @@ export type Condition = { variable_selector?: ValueSelector key?: string // sub variable key comparison_operator?: ComparisonOperator - value: string | string[] + value: string | string[] | boolean numberVarType?: NumberVarType sub_variable_condition?: CaseItem } diff --git a/web/app/components/workflow/nodes/loop/use-config.ts b/web/app/components/workflow/nodes/loop/use-config.ts index 4c6e07c9c..87f3d65a9 100644 --- a/web/app/components/workflow/nodes/loop/use-config.ts +++ b/web/app/components/workflow/nodes/loop/use-config.ts @@ -63,7 +63,7 @@ const useConfig = (id: string, payload: LoopNodeType) => { varType: varItem.type, variable_selector: valueSelector, comparison_operator: getOperators(varItem.type, getIsVarFileAttribute(valueSelector) ? { key: valueSelector.slice(-1)[0] } : undefined)[0], - value: '', + value: varItem.type === VarType.boolean ? 'false' : '', }) }) setInputs(newInputs) diff --git a/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts b/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts index 394ab9b16..6a1b6b20f 100644 --- a/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts +++ b/web/app/components/workflow/nodes/loop/use-single-run-form-params.ts @@ -107,7 +107,7 @@ const useSingleRunFormParams = ({ }, [runResult, loopRunResult, t]) const setInputVarValues = useCallback((newPayload: Record) => { - setRunInputData(newPayload) + setRunInputData(newPayload) }, [setRunInputData]) const inputVarValues = (() => { @@ -149,16 +149,15 @@ const useSingleRunFormParams = ({ }) payload.loop_variables?.forEach((loopVariable) => { - if(loopVariable.value_type === ValueType.variable) + if (loopVariable.value_type === ValueType.variable) allInputs.push(loopVariable.value) }) const inputVarsFromValue: InputVar[] = [] const varInputs = [...varSelectorsToVarInputs(allInputs), ...inputVarsFromValue] - const existVarsKey: Record = {} const uniqueVarInputs: InputVar[] = [] varInputs.forEach((input) => { - if(!input) + if (!input) return if (!existVarsKey[input.variable]) { existVarsKey[input.variable] = true @@ -191,7 +190,7 @@ const useSingleRunFormParams = ({ if (condition.variable_selector) vars.push(condition.variable_selector) - if(condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) + if (condition.sub_variable_condition && condition.sub_variable_condition.conditions?.length) vars.push(...getVarFromCaseItem(condition.sub_variable_condition)) return vars } @@ -203,7 +202,7 @@ const useSingleRunFormParams = ({ vars.push(...conditionVars) }) payload.loop_variables?.forEach((loopVariable) => { - if(loopVariable.value_type === ValueType.variable) + if (loopVariable.value_type === ValueType.variable) vars.push(loopVariable.value) }) const hasFilterLoopVars = vars.filter(item => item[0] !== id) diff --git a/web/app/components/workflow/nodes/loop/utils.ts b/web/app/components/workflow/nodes/loop/utils.ts index 2bc9d8926..bc5e6481c 100644 --- a/web/app/components/workflow/nodes/loop/utils.ts +++ b/web/app/components/workflow/nodes/loop/utils.ts @@ -107,6 +107,13 @@ export const getOperators = (type?: VarType, file?: { key: string }) => { ComparisonOperator.empty, ComparisonOperator.notEmpty, ] + case VarType.boolean: + return [ + ComparisonOperator.is, + ComparisonOperator.isNot, + ComparisonOperator.empty, + ComparisonOperator.notEmpty, + ] case VarType.object: return [ ComparisonOperator.empty, diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx index 46b3ac381..165ace458 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx @@ -35,7 +35,7 @@ type Props = { onCancel?: () => void } -const TYPES = [ParamType.string, ParamType.number, ParamType.arrayString, ParamType.arrayNumber, ParamType.arrayObject] +const TYPES = [ParamType.string, ParamType.number, ParamType.bool, ParamType.arrayString, ParamType.arrayNumber, ParamType.arrayObject, ParamType.arrayBool] const AddExtractParameter: FC = ({ type, diff --git a/web/app/components/workflow/nodes/parameter-extractor/types.ts b/web/app/components/workflow/nodes/parameter-extractor/types.ts index f5ba717be..49e104154 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/types.ts +++ b/web/app/components/workflow/nodes/parameter-extractor/types.ts @@ -3,11 +3,12 @@ import type { CommonNodeType, Memory, ModelConfig, ValueSelector, VisionSetting export enum ParamType { string = 'string', number = 'number', - bool = 'bool', + bool = 'boolean', select = 'select', arrayString = 'array[string]', arrayNumber = 'array[number]', arrayObject = 'array[object]', + arrayBool = 'array[boolean]', } export type Param = { diff --git a/web/app/components/workflow/nodes/start/components/var-item.tsx b/web/app/components/workflow/nodes/start/components/var-item.tsx index 029547542..e51cd7973 100644 --- a/web/app/components/workflow/nodes/start/components/var-item.tsx +++ b/web/app/components/workflow/nodes/start/components/var-item.tsx @@ -19,7 +19,7 @@ type Props = { className?: string readonly: boolean payload: InputVar - onChange?: (item: InputVar, moreInfo?: MoreInfo) => void + onChange?: (item: InputVar, moreInfo?: MoreInfo) => boolean onRemove?: () => void rightContent?: React.JSX.Element varKeys?: string[] @@ -31,7 +31,7 @@ const VarItem: FC = ({ className, readonly, payload, - onChange = noop, + onChange = () => true, onRemove = noop, rightContent, varKeys = [], @@ -48,7 +48,9 @@ const VarItem: FC = ({ }] = useBoolean(false) const handlePayloadChange = useCallback((payload: InputVar, moreInfo?: MoreInfo) => { - onChange(payload, moreInfo) + const isValid = onChange(payload, moreInfo) + if(!isValid) + return hideEditVarModal() }, [onChange, hideEditVarModal]) return ( diff --git a/web/app/components/workflow/nodes/start/components/var-list.tsx b/web/app/components/workflow/nodes/start/components/var-list.tsx index 024b50a75..bbfeed461 100644 --- a/web/app/components/workflow/nodes/start/components/var-list.tsx +++ b/web/app/components/workflow/nodes/start/components/var-list.tsx @@ -9,6 +9,8 @@ import { v4 as uuid4 } from 'uuid' import { ReactSortable } from 'react-sortablejs' import { RiDraggable } from '@remixicon/react' import cn from '@/utils/classnames' +import { hasDuplicateStr } from '@/utils/var' +import Toast from '@/app/components/base/toast' type Props = { readonly: boolean @@ -28,7 +30,26 @@ const VarList: FC = ({ const newList = produce(list, (draft) => { draft[index] = payload }) + let errorMsgKey = '' + let typeName = '' + if (hasDuplicateStr(newList.map(item => item.variable))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.varName' + } + else if (hasDuplicateStr(newList.map(item => item.label as string))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.labelName' + } + + if (errorMsgKey) { + Toast.notify({ + type: 'error', + message: t(errorMsgKey, { key: t(typeName) }), + }) + return false + } onChange(newList, moreInfo ? { index, payload: moreInfo } : undefined) + return true } }, [list, onChange]) diff --git a/web/app/components/workflow/nodes/start/panel.tsx b/web/app/components/workflow/nodes/start/panel.tsx index eb04ecb36..0a1efd444 100644 --- a/web/app/components/workflow/nodes/start/panel.tsx +++ b/web/app/components/workflow/nodes/start/panel.tsx @@ -34,7 +34,8 @@ const Panel: FC> = ({ } = useConfig(id, data) const handleAddVarConfirm = (payload: InputVar) => { - handleAddVariable(payload) + const isValid = handleAddVariable(payload) + if (!isValid) return hideAddVarModal() } diff --git a/web/app/components/workflow/nodes/start/use-config.ts b/web/app/components/workflow/nodes/start/use-config.ts index c0ade614e..d67b5f790 100644 --- a/web/app/components/workflow/nodes/start/use-config.ts +++ b/web/app/components/workflow/nodes/start/use-config.ts @@ -11,8 +11,12 @@ import { useWorkflow, } from '@/app/components/workflow/hooks' import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud' +import { hasDuplicateStr } from '@/utils/var' +import Toast from '@/app/components/base/toast' +import { useTranslation } from 'react-i18next' const useConfig = (id: string, payload: StartNodeType) => { + const { t } = useTranslation() const { nodesReadOnly: readOnly } = useNodesReadOnly() const { handleOutVarRenameChange, isVarUsedInNodes, removeUsedVarInNodes } = useWorkflow() const isChatMode = useIsChatMode() @@ -80,7 +84,27 @@ const useConfig = (id: string, payload: StartNodeType) => { const newInputs = produce(inputs, (draft: StartNodeType) => { draft.variables.push(payload) }) + const newList = newInputs.variables + let errorMsgKey = '' + let typeName = '' + if(hasDuplicateStr(newList.map(item => item.variable))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.varName' + } + else if(hasDuplicateStr(newList.map(item => item.label as string))) { + errorMsgKey = 'appDebug.varKeyError.keyAlreadyExists' + typeName = 'appDebug.variableConfig.labelName' + } + + if (errorMsgKey) { + Toast.notify({ + type: 'error', + message: t(errorMsgKey, { key: t(typeName) }), + }) + return false + } setInputs(newInputs) + return true }, [inputs, setInputs]) return { readOnly, diff --git a/web/app/components/workflow/nodes/template-transform/use-config.ts b/web/app/components/workflow/nodes/template-transform/use-config.ts index 8be93abdf..fa7eb81ba 100644 --- a/web/app/components/workflow/nodes/template-transform/use-config.ts +++ b/web/app/components/workflow/nodes/template-transform/use-config.ts @@ -65,7 +65,6 @@ const useConfig = (id: string, payload: TemplateTransformNodeType) => { ...defaultConfig, }) } - // eslint-disable-next-line react-hooks/exhaustive-deps }, [defaultConfig]) const handleCodeChange = useCallback((template: string) => { @@ -76,7 +75,7 @@ const useConfig = (id: string, payload: TemplateTransformNodeType) => { }, [setInputs]) const filterVar = useCallback((varPayload: Var) => { - return [VarType.string, VarType.number, VarType.object, VarType.array, VarType.arrayNumber, VarType.arrayString, VarType.arrayObject].includes(varPayload.type) + return [VarType.string, VarType.number, VarType.boolean, VarType.object, VarType.array, VarType.arrayNumber, VarType.arrayString, VarType.arrayBoolean, VarType.arrayObject].includes(varPayload.type) }, []) return { diff --git a/web/app/components/workflow/nodes/variable-assigner/hooks.ts b/web/app/components/workflow/nodes/variable-assigner/hooks.ts index 0e5e10c74..d4e4115a7 100644 --- a/web/app/components/workflow/nodes/variable-assigner/hooks.ts +++ b/web/app/components/workflow/nodes/variable-assigner/hooks.ts @@ -132,7 +132,6 @@ export const useGetAvailableVars = () => { if (!currentNode) return [] - const beforeNodes = getBeforeNodesInSameBranchIncludeParent(nodeId) availableNodes.push(...beforeNodes) const parentNode = nodes.find(node => node.id === currentNode.parentId) @@ -143,7 +142,7 @@ export const useGetAvailableVars = () => { beforeNodes: uniqBy(availableNodes, 'id').filter(node => node.id !== nodeId), isChatMode, hideEnv, - hideChatVar: hideEnv, + hideChatVar: false, filterVar, }) .map(node => ({ diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/array-bool-list.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/array-bool-list.tsx new file mode 100644 index 000000000..5f1dcc229 --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/array-bool-list.tsx @@ -0,0 +1,72 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import { RiAddLine } from '@remixicon/react' +import produce from 'immer' +import RemoveButton from '@/app/components/workflow/nodes/_base/components/remove-button' +import Button from '@/app/components/base/button' +import BoolValue from './bool-value' +import cn from '@/utils/classnames' + +type Props = { + className?: string + list: boolean[] + onChange: (list: boolean[]) => void +} + +const ArrayValueList: FC = ({ + className, + list, + onChange, +}) => { + const { t } = useTranslation() + + const handleChange = useCallback((index: number) => { + return (value: boolean) => { + const newList = produce(list, (draft: any[]) => { + draft[index] = value + }) + onChange(newList) + } + }, [list, onChange]) + + const handleItemRemove = useCallback((index: number) => { + return () => { + const newList = produce(list, (draft) => { + draft.splice(index, 1) + }) + onChange(newList) + } + }, [list, onChange]) + + const handleItemAdd = useCallback(() => { + const newList = produce(list, (draft: any[]) => { + draft.push(false) + }) + onChange(newList) + }, [list, onChange]) + + return ( +
+ {list.map((item, index) => ( +
+ + + +
+ ))} + +
+ ) +} +export default React.memo(ArrayValueList) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/bool-value.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/bool-value.tsx new file mode 100644 index 000000000..864fefd9a --- /dev/null +++ b/web/app/components/workflow/panel/chat-variable-panel/components/bool-value.tsx @@ -0,0 +1,37 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import OptionCard from '../../../nodes/_base/components/option-card' + +type Props = { + value: boolean + onChange: (value: boolean) => void +} + +const BoolValue: FC = ({ + value, + onChange, +}) => { + const booleanValue = value + const handleChange = useCallback((newValue: boolean) => { + return () => { + onChange(newValue) + } + }, [onChange]) + + return ( +
+ + +
+ ) +} +export default React.memo(BoolValue) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx index 15292b928..5e476027e 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -16,6 +16,15 @@ import type { ConversationVariable } from '@/app/components/workflow/types' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' import cn from '@/utils/classnames' +import BoolValue from './bool-value' +import ArrayBoolList from './array-bool-list' +import { + arrayBoolPlaceholder, + arrayNumberPlaceholder, + arrayObjectPlaceholder, + arrayStringPlaceholder, + objectPlaceholder, +} from '@/app/components/workflow/panel/chat-variable-panel/utils' import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var' export type ModalPropsType = { @@ -33,39 +42,14 @@ type ObjectValueItem = { const typeList = [ ChatVarType.String, ChatVarType.Number, + ChatVarType.Boolean, ChatVarType.Object, ChatVarType.ArrayString, ChatVarType.ArrayNumber, + ChatVarType.ArrayBoolean, ChatVarType.ArrayObject, ] -const objectPlaceholder = `# example -# { -# "name": "ray", -# "age": 20 -# }` -const arrayStringPlaceholder = `# example -# [ -# "value1", -# "value2" -# ]` -const arrayNumberPlaceholder = `# example -# [ -# 100, -# 200 -# ]` -const arrayObjectPlaceholder = `# example -# [ -# { -# "name": "ray", -# "age": 20 -# }, -# { -# "name": "lily", -# "age": 18 -# } -# ]` - const ChatVariableModal = ({ chatVar, onClose, @@ -94,6 +78,8 @@ const ChatVariableModal = ({ return arrayNumberPlaceholder if (type === ChatVarType.ArrayObject) return arrayObjectPlaceholder + if (type === ChatVarType.ArrayBoolean) + return arrayBoolPlaceholder return objectPlaceholder }, [type]) const getObjectValue = useCallback(() => { @@ -122,12 +108,16 @@ const ChatVariableModal = ({ return value || '' case ChatVarType.Number: return value || 0 + case ChatVarType.Boolean: + return value === undefined ? true : value case ChatVarType.Object: return editInJSON ? value : formatValueFromObject(objectValue) case ChatVarType.ArrayString: case ChatVarType.ArrayNumber: case ChatVarType.ArrayObject: return value?.filter(Boolean) || [] + case ChatVarType.ArrayBoolean: + return value || [] } } @@ -157,6 +147,10 @@ const ChatVariableModal = ({ setEditInJSON(true) if (v === ChatVarType.String || v === ChatVarType.Number || v === ChatVarType.Object) setEditInJSON(false) + if(v === ChatVarType.Boolean) + setValue(false) + if (v === ChatVarType.ArrayBoolean) + setValue([false]) setType(v) } @@ -202,6 +196,11 @@ const ChatVariableModal = ({ setValue(value?.length ? value : [undefined]) } } + + if(type === ChatVarType.ArrayBoolean) { + if(editInJSON) + setEditorContent(JSON.stringify(value.map((item: boolean) => item ? 'True' : 'False'))) + } setEditInJSON(editInJSON) } @@ -213,7 +212,16 @@ const ChatVariableModal = ({ else { setEditorContent(content) try { - const newValue = JSON.parse(content) + let newValue = JSON.parse(content) + if(type === ChatVarType.ArrayBoolean) { + newValue = newValue.map((item: string | boolean) => { + if (item === 'True' || item === 'true' || item === true) + return true + if (item === 'False' || item === 'false' || item === false) + return false + return undefined + }).filter((item?: boolean) => item !== undefined) + } setValue(newValue) } catch { @@ -304,7 +312,7 @@ const ChatVariableModal = ({
{t('workflow.chatVariable.modal.value')}
- {(type === ChatVarType.ArrayString || type === ChatVarType.ArrayNumber) && ( + {(type === ChatVarType.ArrayString || type === ChatVarType.ArrayNumber || type === ChatVarType.ArrayBoolean) && (