[Chore/Refactor] Improve type annotations in models module (#25281)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-08 09:42:27 +08:00
committed by GitHub
parent e1f871fefe
commit 9b8a03b53b
23 changed files with 332 additions and 251 deletions

View File

@@ -1,10 +1,10 @@
import enum
import json
from datetime import datetime
from typing import Optional
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
@@ -225,11 +225,11 @@ class Tenant(Base):
)
@property
def custom_config_dict(self):
def custom_config_dict(self) -> dict[str, Any]:
return json.loads(self.custom_config) if self.custom_config else {}
@custom_config_dict.setter
def custom_config_dict(self, value: dict):
def custom_config_dict(self, value: dict[str, Any]) -> None:
self.custom_config = json.dumps(value)

View File

@@ -286,7 +286,7 @@ class DatasetProcessRule(Base):
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
}
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"dataset_id": self.dataset_id,
@@ -295,7 +295,7 @@ class DatasetProcessRule(Base):
}
@property
def rules_dict(self):
def rules_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.rules) if self.rules else None
except JSONDecodeError:
@@ -392,10 +392,10 @@ class Document(Base):
return status
@property
def data_source_info_dict(self):
def data_source_info_dict(self) -> dict[str, Any] | None:
if self.data_source_info:
try:
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
except JSONDecodeError:
data_source_info_dict = {}
@@ -403,10 +403,10 @@ class Document(Base):
return None
@property
def data_source_detail_dict(self):
def data_source_detail_dict(self) -> dict[str, Any]:
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
@@ -425,7 +425,8 @@ class Document(Base):
}
}
elif self.data_source_type in {"notion_import", "website_crawl"}:
return json.loads(self.data_source_info)
result: dict[str, Any] = json.loads(self.data_source_info)
return result
return {}
@property
@@ -471,7 +472,7 @@ class Document(Base):
return self.updated_at
@property
def doc_metadata_details(self):
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
@@ -481,9 +482,9 @@ class Document(Base):
)
.all()
)
metadata_list = []
metadata_list: list[dict[str, Any]] = []
for metadata in document_metadatas:
metadata_dict = {
metadata_dict: dict[str, Any] = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
@@ -497,13 +498,13 @@ class Document(Base):
return None
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
def process_rule_dict(self) -> dict[str, Any] | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict()
return None
def get_built_in_fields(self):
built_in_fields = []
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
built_in_fields.append(
{
"id": "built-in",
@@ -546,7 +547,7 @@ class Document(Base):
)
return built_in_fields
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@@ -592,13 +593,13 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": self.dataset.to_dict() if self.dataset else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]):
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -711,46 +712,48 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
def get_child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def get_child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
@property
def sign_content(self):
def sign_content(self) -> str:
return self.get_sign_content()
def get_sign_content(self):
signed_urls = []
def get_sign_content(self) -> str:
signed_urls: list[tuple[int, int, str]] = []
text = self.content
# For data before v0.10.0
@@ -890,17 +893,22 @@ class DatasetKeywordTable(Base):
)
@property
def keyword_table_dict(self):
def keyword_table_dict(self) -> dict[str, set[Any]] | None:
class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
def __init__(self, *args: Any, **kwargs: Any) -> None:
def object_hook(dct: Any) -> Any:
if isinstance(dct, dict):
result: dict[str, Any] = {}
items = cast(dict[str, Any], dct).items()
for keyword, node_idxs in items:
if isinstance(node_idxs, list):
result[keyword] = set(cast(list[Any], node_idxs))
else:
result[keyword] = node_idxs
return result
return dct
def object_hook(self, dct):
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
@@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base):
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base):
}
@property
def settings_dict(self):
def settings_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.settings) if self.settings else None
except JSONDecodeError:
return None
@property
def dataset_bindings(self):
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
@@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base):
)
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -24,7 +24,7 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from libs.helper import generate_string # type: ignore[import-not-found]
from .account import Account, Tenant
from .base import Base
@@ -98,7 +98,7 @@ class App(Base):
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def desc_or_prompt(self):
def desc_or_prompt(self) -> str:
if self.description:
return self.description
else:
@@ -109,12 +109,12 @@ class App(Base):
return ""
@property
def site(self):
def site(self) -> Optional["Site"]:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self):
def app_model_config(self) -> Optional["AppModelConfig"]:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@@ -130,11 +130,11 @@ class App(Base):
return None
@property
def api_base_url(self):
def api_base_url(self) -> str:
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
@property
def tenant(self):
def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@@ -162,7 +162,7 @@ class App(Base):
return str(self.mode)
@property
def deleted_tools(self):
def deleted_tools(self) -> list[dict[str, str]]:
from core.tools.tool_manager import ToolManager
from services.plugin.plugin_service import PluginService
@@ -242,7 +242,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools = []
deleted_tools: list[dict[str, str]] = []
for tool in tools:
keys = list(tool.keys())
@@ -275,7 +275,7 @@ class App(Base):
return deleted_tools
@property
def tags(self):
def tags(self) -> list["Tag"]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@@ -291,7 +291,7 @@ class App(Base):
return tags or []
@property
def author_name(self):
def author_name(self) -> Optional[str]:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
@@ -334,20 +334,20 @@ class AppModelConfig(Base):
file_upload = mapped_column(sa.Text)
@property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def model_dict(self):
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
@property
def suggested_questions_list(self):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self):
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
@@ -355,19 +355,19 @@ class AppModelConfig(Base):
)
@property
def speech_to_text_dict(self):
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
@property
def text_to_speech_dict(self):
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
@property
def retriever_resource_dict(self):
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
@property
def annotation_reply_dict(self):
def annotation_reply_dict(self) -> dict[str, Any]:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@@ -390,11 +390,11 @@ class AppModelConfig(Base):
return {"enabled": False}
@property
def more_like_this_dict(self):
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property
def sensitive_word_avoidance_dict(self):
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
@@ -402,15 +402,15 @@ class AppModelConfig(Base):
)
@property
def external_data_tools_list(self) -> list[dict]:
def external_data_tools_list(self) -> list[dict[str, Any]]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self):
def user_input_form_list(self) -> list[dict[str, Any]]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self):
def agent_mode_dict(self) -> dict[str, Any]:
return (
json.loads(self.agent_mode)
if self.agent_mode
@@ -418,17 +418,17 @@ class AppModelConfig(Base):
)
@property
def chat_prompt_config_dict(self):
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self):
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self):
def dataset_configs_dict(self) -> dict[str, Any]:
if self.dataset_configs:
dataset_configs: dict = json.loads(self.dataset_configs)
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
@@ -438,7 +438,7 @@ class AppModelConfig(Base):
}
@property
def file_upload_dict(self):
def file_upload_dict(self) -> dict[str, Any]:
return (
json.loads(self.file_upload)
if self.file_upload
@@ -452,7 +452,7 @@ class AppModelConfig(Base):
}
)
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@@ -546,7 +546,7 @@ class RecommendedApp(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@@ -570,12 +570,12 @@ class InstalledApp(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def tenant(self):
def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@@ -622,7 +622,7 @@ class Conversation(Base):
mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@@ -652,7 +652,7 @@ class Conversation(Base):
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
# Convert file mapping to File object
@@ -660,22 +660,39 @@ class Conversation(Base):
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs
@@ -685,8 +702,10 @@ class Conversation(Base):
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property
@@ -826,7 +845,7 @@ class Conversation(Base):
)
@property
def app(self):
def app(self) -> Optional[App]:
return db.session.query(App).where(App.id == self.app_id).first()
@property
@@ -839,7 +858,7 @@ class Conversation(Base):
return None
@property
def from_account_name(self):
def from_account_name(self) -> Optional[str]:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account:
@@ -848,10 +867,10 @@ class Conversation(Base):
return None
@property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@@ -897,7 +916,7 @@ class Message(Base):
model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@@ -924,28 +943,45 @@ class Message(Base):
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
@property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs
@inputs.setter
@@ -954,8 +990,10 @@ class Message(Base):
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property
@@ -1083,15 +1121,15 @@ class Message(Base):
return None
@property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
@property
def message_metadata_dict(self):
def message_metadata_dict(self) -> dict[str, Any]:
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self):
def agent_thoughts(self) -> list["MessageAgentThought"]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
@@ -1100,11 +1138,11 @@ class Message(Base):
)
@property
def retriever_resources(self):
def retriever_resources(self) -> Any | list[Any]:
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self):
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
@@ -1112,7 +1150,7 @@ class Message(Base):
if not current_app:
raise ValueError(f"App {self.app_id} not found")
files = []
files: list[File] = []
for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.upload_file_id is None:
@@ -1159,7 +1197,7 @@ class Message(Base):
)
files.append(file)
result = [
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
@@ -1176,7 +1214,7 @@ class Message(Base):
return None
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@@ -1200,7 +1238,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]) -> "Message":
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1250,7 +1288,7 @@ class MessageFeedback(Base):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@@ -1435,7 +1473,18 @@ class EndUser(Base, UserMixin):
type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
_is_anonymous: Mapped[bool] = mapped_column(
"is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true")
)
@property
def is_anonymous(self) -> Literal[False]:
return False
@is_anonymous.setter
def is_anonymous(self, value: bool) -> None:
self._is_anonymous = value
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -1461,7 +1510,7 @@ class AppMCPServer(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
@@ -1518,7 +1567,7 @@ class Site(Base):
self._custom_disclaimer = value
@staticmethod
def generate_code(n):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
@@ -1549,7 +1598,7 @@ class ApiToken(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
def generate_api_key(prefix: str, n: int) -> str:
while True:
result = prefix + generate_string(n)
if db.session.scalar(select(exists().where(ApiToken.token == result))):
@@ -1689,7 +1738,7 @@ class MessageAgentThought(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def files(self):
def files(self) -> list[Any]:
if self.message_files:
return cast(list[Any], json.loads(self.message_files))
else:
@@ -1700,32 +1749,32 @@ class MessageAgentThought(Base):
return self.tool.split(";") if self.tool else []
@property
def tool_labels(self):
def tool_labels(self) -> dict[str, Any]:
try:
if self.tool_labels_str:
return cast(dict, json.loads(self.tool_labels_str))
return cast(dict[str, Any], json.loads(self.tool_labels_str))
else:
return {}
except Exception:
return {}
@property
def tool_meta(self):
def tool_meta(self) -> dict[str, Any]:
try:
if self.tool_meta_str:
return cast(dict, json.loads(self.tool_meta_str))
return cast(dict[str, Any], json.loads(self.tool_meta_str))
else:
return {}
except Exception:
return {}
@property
def tool_inputs_dict(self):
def tool_inputs_dict(self) -> dict[str, Any]:
tools = self.tools
try:
if self.tool_input:
data = json.loads(self.tool_input)
result = {}
result: dict[str, Any] = {}
for tool in tools:
if tool in data:
result[tool] = data[tool]
@@ -1741,12 +1790,12 @@ class MessageAgentThought(Base):
return {}
@property
def tool_outputs_dict(self):
def tool_outputs_dict(self) -> dict[str, Any]:
tools = self.tools
try:
if self.observation:
data = json.loads(self.observation)
result = {}
result: dict[str, Any] = {}
for tool in tools:
if tool in data:
result[tool] = data[tool]
@@ -1844,14 +1893,14 @@ class TraceAppConfig(Base):
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@property
def tracing_config_dict(self):
def tracing_config_dict(self) -> dict[str, Any]:
return self.tracing_config or {}
@property
def tracing_config_str(self):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,

View File

@@ -17,7 +17,7 @@ class ProviderType(Enum):
SYSTEM = "system"
@staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderType":
for member in ProviderType:
if member.value == value:
return member
@@ -35,7 +35,7 @@ class ProviderQuotaType(Enum):
"""hosted trial quota"""
@staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member

View File

@@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Optional, cast
from typing import Any, Optional, cast
from urllib.parse import urlparse
import sqlalchemy as sa
@@ -54,8 +54,8 @@ class ToolOAuthTenantClient(Base):
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property
def oauth_params(self):
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
@@ -96,8 +96,8 @@ class BuiltinToolProvider(Base):
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
@property
def credentials(self):
return cast(dict, json.loads(self.encrypted_credentials))
def credentials(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
class ApiToolProvider(Base):
@@ -146,8 +146,8 @@ class ApiToolProvider(Base):
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self):
return dict(json.loads(self.credentials_str))
def credentials(self) -> dict[str, Any]:
return dict[str, Any](json.loads(self.credentials_str))
@property
def user(self) -> Account | None:
@@ -289,9 +289,9 @@ class MCPToolProvider(Base):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self):
def credentials(self) -> dict[str, Any]:
try:
return cast(dict, json.loads(self.encrypted_credentials)) or {}
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except Exception:
return {}
@@ -327,12 +327,12 @@ class MCPToolProvider(Base):
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self):
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
@@ -340,7 +340,7 @@ class MCPToolProvider(Base):
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials) # type: ignore
return encrypter.decrypt(self.credentials)
class ToolModelInvoke(Base):

View File

@@ -1,29 +1,34 @@
import enum
from typing import Generic, TypeVar
import uuid
from typing import Any, Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
class StringUUID(TypeDecorator):
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
elif dialect.name == "postgresql":
return str(value)
else:
return value.hex
if isinstance(value, uuid.UUID):
return value.hex
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
return str(value)
@@ -32,7 +37,7 @@ class StringUUID(TypeDecorator):
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator, Generic[_E]):
class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR
cache_ok = True
@@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect):
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
elif isinstance(value, str):
self._enum_class(value)
return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value, dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None:
return value
if not isinstance(value, str):
raise TypeError(f"expected str, got {type(value)}")
# Type annotation guarantees value is str at this point
return self._enum_class(value)
def compare_values(self, x, y):
def compare_values(self, x: _E | None, y: _E | None) -> bool:
if x is None or y is None:
return x is y
return x == y

View File

@@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -224,7 +224,7 @@ class Workflow(Base):
raise WorkflowDataError("nodes not found in workflow graph")
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict)
@@ -289,7 +289,7 @@ class Workflow(Base):
def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {}
def user_input_form(self, to_old_structure: bool = False):
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
# get start node from graph
if not self.graph:
return []
@@ -306,7 +306,7 @@ class Workflow(Base):
variables: list[Any] = start_node.get("data", {}).get("variables", [])
if to_old_structure:
old_structure_variables = []
old_structure_variables: list[dict[str, Any]] = []
for variable in variables:
old_structure_variables.append({variable["type"]: variable})
@@ -346,9 +346,7 @@ class Workflow(Base):
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@@ -362,17 +360,18 @@ class Workflow(Base):
]
# decrypt secret variables value
def decrypt_func(var):
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
raise AssertionError("this statement should be unreachable.")
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
decrypt_func(var) for var in results
]
return decrypted_results
@environment_variables.setter
@@ -400,7 +399,7 @@ class Workflow(Base):
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var):
def encrypt_func(var: Variable) -> Variable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -430,9 +429,7 @@ class Workflow(Base):
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
@@ -577,7 +574,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict) -> "WorkflowRun":
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base):
__tablename__ = "workflow_node_executions"
@declared_attr
def __table_args__(cls): # noqa
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
@@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base):
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore
cls.created_at.desc(),
),
)
@@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base):
return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property
def extras(self):
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager
extras = {}
extras: dict[str, Any] = {}
if self.execution_metadata_dict:
from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"]
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
@@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base):
# making this attribute harder to access from outside the class.
__value: Segment | None
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state
@@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base):
self.__value = None
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
selector: Any = json.loads(self.selector)
if not isinstance(selector, list):
logger.error(
"invalid selector loaded from database, type=%s, value=%s",
type(selector),
type(selector).__name__,
self.selector,
)
raise ValueError("invalid selector.")
return selector
return cast(list[str], selector)
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
@@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base):
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict):
if not maybe_file_object(value):
return value
return cast(Any, value)
return File.model_validate(value)
elif isinstance(value, list) and value:
first = value[0]
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return value
return [File.model_validate(i) for i in value]
return cast(Any, value)
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
return cast(Any, file_list)
else:
return value
return cast(Any, value)
@classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: