[Chore/Refactor] Improve type annotations in models module (#25281)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-08 09:42:27 +08:00
committed by GitHub
parent e1f871fefe
commit 9b8a03b53b
23 changed files with 332 additions and 251 deletions

View File

@@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -224,7 +224,7 @@ class Workflow(Base):
raise WorkflowDataError("nodes not found in workflow graph")
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict)
@@ -289,7 +289,7 @@ class Workflow(Base):
def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {}
def user_input_form(self, to_old_structure: bool = False):
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
# get start node from graph
if not self.graph:
return []
@@ -306,7 +306,7 @@ class Workflow(Base):
variables: list[Any] = start_node.get("data", {}).get("variables", [])
if to_old_structure:
old_structure_variables = []
old_structure_variables: list[dict[str, Any]] = []
for variable in variables:
old_structure_variables.append({variable["type"]: variable})
@@ -346,9 +346,7 @@ class Workflow(Base):
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@@ -362,17 +360,18 @@ class Workflow(Base):
]
# decrypt secret variables value
def decrypt_func(var):
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
raise AssertionError("this statement should be unreachable.")
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
decrypt_func(var) for var in results
]
return decrypted_results
@environment_variables.setter
@@ -400,7 +399,7 @@ class Workflow(Base):
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var):
def encrypt_func(var: Variable) -> Variable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -430,9 +429,7 @@ class Workflow(Base):
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
@@ -577,7 +574,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict) -> "WorkflowRun":
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base):
__tablename__ = "workflow_node_executions"
@declared_attr
def __table_args__(cls): # noqa
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
@@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base):
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore
cls.created_at.desc(),
),
)
@@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base):
return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property
def extras(self):
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager
extras = {}
extras: dict[str, Any] = {}
if self.execution_metadata_dict:
from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"]
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
@@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base):
# making this attribute harder to access from outside the class.
__value: Segment | None
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state
@@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base):
self.__value = None
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
selector: Any = json.loads(self.selector)
if not isinstance(selector, list):
logger.error(
"invalid selector loaded from database, type=%s, value=%s",
type(selector),
type(selector).__name__,
self.selector,
)
raise ValueError("invalid selector.")
return selector
return cast(list[str], selector)
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
@@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base):
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict):
if not maybe_file_object(value):
return value
return cast(Any, value)
return File.model_validate(value)
elif isinstance(value, list) and value:
first = value[0]
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return value
return [File.model_validate(i) for i in value]
return cast(Any, value)
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
return cast(Any, file_list)
else:
return value
return cast(Any, value)
@classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: