feat: support bool type variable frontend (#24437)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@@ -119,6 +120,14 @@ class CodeNode(BaseNode):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
@@ -173,6 +182,8 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
@@ -232,7 +243,7 @@ class CodeNode(BaseNode):
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
@@ -249,18 +260,28 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == "number":
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == "string":
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == "array[number]":
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@@ -278,10 +299,17 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[string]":
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@@ -302,7 +330,7 @@ class CodeNode(BaseNode):
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[object]":
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@@ -340,6 +368,22 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
@@ -374,3 +418,16 @@ class CodeNode(BaseNode):
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
|
@@ -1,11 +1,31 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
@@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
|
@@ -1,36 +1,43 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_Condition = Literal[
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
"contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"in",
|
||||
"empty",
|
||||
"not contains",
|
||||
"is not",
|
||||
"not in",
|
||||
"not empty",
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
"=",
|
||||
"≠",
|
||||
"<",
|
||||
">",
|
||||
"≥",
|
||||
"≤",
|
||||
]
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: _Condition = "contains"
|
||||
value: str | Sequence[str] = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
@@ -38,10 +45,10 @@ class FilterBy(BaseModel):
|
||||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderBy(BaseModel):
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Literal["asc", "desc"] = "asc"
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
@@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
@@ -1,18 +1,40 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
@@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
@@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
@@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
@@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return lambda x: not _is(value)(x)
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
@@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
@@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
@@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
@@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
@@ -3,7 +3,7 @@ import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
@@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
@@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -161,7 +161,7 @@ class LLMNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
|
@@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
@@ -404,11 +404,11 @@ class LoopNode(BaseNode):
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
@@ -522,21 +522,30 @@ class LoopNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
elif var_type == SegmentType.ARRAY_BOOLEAN:
|
||||
value = original_value
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
@@ -1,10 +1,46 @@
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
@@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
@@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == "select":
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
@@ -1,3 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
@@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
|
@@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
@@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
@@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
@@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
@@ -4,9 +4,11 @@ from core.variables import SegmentType
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
@@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
@@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
@@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
@@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
|
||||
return isinstance(value, bool)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
@@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
@@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
|
||||
class Condition(BaseModel):
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
sub_variable_condition: SubVariableCondition | None = None
|
||||
|
@@ -1,13 +1,27 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import FileAttribute, file_manager
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .entities import Condition, SubCondition, SupportedComparisonOperator
|
||||
|
||||
|
||||
def _convert_to_bool(value: Any) -> bool:
|
||||
if isinstance(value, int):
|
||||
return bool(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
loaded = json.loads(value)
|
||||
if isinstance(loaded, (int, bool)):
|
||||
return bool(loaded)
|
||||
|
||||
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(
|
||||
self,
|
||||
@@ -48,9 +62,16 @@ class ConditionProcessor:
|
||||
)
|
||||
else:
|
||||
actual_value = variable.value if variable else None
|
||||
expected_value = condition.value
|
||||
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = variable_pool.convert_template(expected_value).text
|
||||
# Here we need to explicit convet the input string to boolean.
|
||||
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
|
||||
# The following two lines is for compatibility with existing workflows.
|
||||
if isinstance(expected_value, list):
|
||||
expected_value = [_convert_to_bool(i) for i in expected_value]
|
||||
else:
|
||||
expected_value = _convert_to_bool(expected_value)
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
@@ -77,7 +98,7 @@ def _evaluate_condition(
|
||||
*,
|
||||
operator: SupportedComparisonOperator,
|
||||
value: Any,
|
||||
expected: str | Sequence[str] | None,
|
||||
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
|
||||
) -> bool:
|
||||
match operator:
|
||||
case "contains":
|
||||
@@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected not in value:
|
||||
@@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return True
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected in value:
|
||||
@@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value != expected:
|
||||
return False
|
||||
@@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value == expected:
|
||||
return False
|
||||
@@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
Reference in New Issue
Block a user