[Chore/Refactor] Improve type annotations in models module (#25281)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -224,7 +224,7 @@ class Workflow(Base):
|
||||
raise WorkflowDataError("nodes not found in workflow graph")
|
||||
|
||||
try:
|
||||
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
assert isinstance(node_config, dict)
|
||||
@@ -289,7 +289,7 @@ class Workflow(Base):
|
||||
def features_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
|
||||
def user_input_form(self, to_old_structure: bool = False):
|
||||
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
|
||||
# get start node from graph
|
||||
if not self.graph:
|
||||
return []
|
||||
@@ -306,7 +306,7 @@ class Workflow(Base):
|
||||
variables: list[Any] = start_node.get("data", {}).get("variables", [])
|
||||
|
||||
if to_old_structure:
|
||||
old_structure_variables = []
|
||||
old_structure_variables: list[dict[str, Any]] = []
|
||||
for variable in variables:
|
||||
old_structure_variables.append({variable["type"]: variable})
|
||||
|
||||
@@ -346,9 +346,7 @@ class Workflow(Base):
|
||||
|
||||
@property
|
||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
# _environment_variables is guaranteed to be non-None due to server_default="{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
@@ -362,17 +360,18 @@ class Workflow(Base):
|
||||
]
|
||||
|
||||
# decrypt secret variables value
|
||||
def decrypt_func(var):
|
||||
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
|
||||
return var
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
# Other variable types are not supported for environment variables
|
||||
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
|
||||
|
||||
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
|
||||
map(decrypt_func, results)
|
||||
)
|
||||
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
|
||||
decrypt_func(var) for var in results
|
||||
]
|
||||
return decrypted_results
|
||||
|
||||
@environment_variables.setter
|
||||
@@ -400,7 +399,7 @@ class Workflow(Base):
|
||||
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
|
||||
|
||||
# encrypt secret variables value
|
||||
def encrypt_func(var):
|
||||
def encrypt_func(var: Variable) -> Variable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
else:
|
||||
@@ -430,9 +429,7 @@ class Workflow(Base):
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
@@ -577,7 +574,7 @@ class WorkflowRun(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "WorkflowRun":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
||||
return cls(
|
||||
id=data.get("id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
@@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base):
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls): # noqa
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
@@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base):
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(), # type: ignore
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base):
|
||||
return json.loads(self.execution_metadata) if self.execution_metadata else {}
|
||||
|
||||
@property
|
||||
def extras(self):
|
||||
def extras(self) -> dict[str, Any]:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
extras = {}
|
||||
extras: dict[str, Any] = {}
|
||||
if self.execution_metadata_dict:
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
|
||||
tool_info = self.execution_metadata_dict["tool_info"]
|
||||
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
|
||||
extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_type=tool_info["provider_type"],
|
||||
@@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base):
|
||||
# making this attribute harder to access from outside the class.
|
||||
__value: Segment | None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
The constructor of `WorkflowDraftVariable` is not intended for
|
||||
direct use outside this file. Its solo purpose is setup private state
|
||||
@@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base):
|
||||
self.__value = None
|
||||
|
||||
def get_selector(self) -> list[str]:
|
||||
selector = json.loads(self.selector)
|
||||
selector: Any = json.loads(self.selector)
|
||||
if not isinstance(selector, list):
|
||||
logger.error(
|
||||
"invalid selector loaded from database, type=%s, value=%s",
|
||||
type(selector),
|
||||
type(selector).__name__,
|
||||
self.selector,
|
||||
)
|
||||
raise ValueError("invalid selector.")
|
||||
return selector
|
||||
return cast(list[str], selector)
|
||||
|
||||
def _set_selector(self, value: list[str]):
|
||||
self.selector = json.dumps(value)
|
||||
@@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base):
|
||||
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
|
||||
if isinstance(value, dict):
|
||||
if not maybe_file_object(value):
|
||||
return value
|
||||
return cast(Any, value)
|
||||
return File.model_validate(value)
|
||||
elif isinstance(value, list) and value:
|
||||
first = value[0]
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
if not maybe_file_object(first):
|
||||
return value
|
||||
return [File.model_validate(i) for i in value]
|
||||
return cast(Any, value)
|
||||
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return value
|
||||
return cast(Any, value)
|
||||
|
||||
@classmethod
|
||||
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
|
||||
|
Reference in New Issue
Block a user