From 9b8a03b53b1163ffeffc6646ad827a375b498d77 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 8 Sep 2025 09:42:27 +0800 Subject: [PATCH] [Chore/Refactor] Improve type annotations in models module (#25281) Signed-off-by: -LAN- Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/controllers/console/apikey.py | 2 +- .../console/datasets/datasets_document.py | 6 + api/controllers/console/explore/parameter.py | 2 + api/controllers/console/explore/workflow.py | 4 + api/core/app/apps/completion/app_generator.py | 3 + api/core/rag/extractor/notion_extractor.py | 3 +- api/core/tools/mcp_tool/provider.py | 4 +- api/core/tools/tool_manager.py | 4 +- api/models/account.py | 8 +- api/models/dataset.py | 134 +++++----- api/models/model.py | 251 +++++++++++------- api/models/provider.py | 4 +- api/models/tools.py | 24 +- api/models/types.py | 38 +-- api/models/workflow.py | 62 ++--- api/pyrightconfig.json | 1 - api/services/agent_service.py | 4 +- api/services/app_service.py | 5 +- api/services/audio_service.py | 6 +- api/services/dataset_service.py | 7 +- api/services/external_knowledge_service.py | 5 +- .../tools/mcp_tools_manage_service.py | 2 +- .../unit_tests/models/test_types_enum_text.py | 4 +- 23 files changed, 332 insertions(+), 251 deletions(-) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 758b574d1..cfd5f73ad 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource): custom="max_keys_exceeded", ) - key = ApiToken.generate_api_key(self.token_prefix, 24) + key = ApiToken.generate_api_key(self.token_prefix or "", 24) api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_user.current_tenant_id diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index f9703f5a2..c9c0b6a5c 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -475,6 +475,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_source_info = document.data_source_info_dict if document.data_source_type == "upload_file": + if not data_source_info: + continue file_id = data_source_info["upload_file_id"] file_detail = ( db.session.query(UploadFile) @@ -491,6 +493,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): extract_settings.append(extract_setting) elif document.data_source_type == "notion_import": + if not data_source_info: + continue extract_setting = ExtractSetting( datasource_type=DatasourceType.NOTION.value, notion_info={ @@ -503,6 +507,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): ) extract_settings.append(extract_setting) elif document.data_source_type == "website_crawl": + if not data_source_info: + continue extract_setting = ExtractSetting( datasource_type=DatasourceType.WEBSITE.value, website_info={ diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c36874475..d9afb5bab 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" app_model = installed_app.app + if not app_model: + raise ValueError("App not found") return AppService().get_app_meta(app_model) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 0a5a88d6f..d80bfcfab 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,6 +35,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): Run workflow """ app_model = installed_app.app + if not app_model: + raise NotWorkflowAppError() app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() @@ -73,6 +75,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): Stop workflow task """ app_model = installed_app.app + if not app_model: + raise NotWorkflowAppError() app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 6e43e5ec9..8485ce751 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise MessageNotExistsError() current_app_model_config = app_model.app_model_config + if not current_app_model_config: + raise MoreLikeThisDisabledError() + more_like_this = current_app_model_config.more_like_this_dict if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 206b2bb92..fa96d73cf 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -334,7 +334,8 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict - data_source_info["last_edited_time"] = last_edited_time + if data_source_info: + data_source_info["last_edited_time"] = last_edited_time db.session.query(DocumentModel).filter_by(id=document_model.id).update( {DocumentModel.data_source_info: json.dumps(data_source_info)} diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index fa99cccb8..dd9d3a137 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, Optional +from typing import Any, Optional, Self from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController @@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController): return ToolProviderType.MCP @classmethod - def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": + def from_db(cls, db_provider: MCPToolProvider) -> Self: """ from db provider """ diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 834f58be6..00fc57a3f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -773,7 +773,7 @@ class ToolManager: if provider is None: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") - controller = MCPToolProviderController._from_db(provider) + controller = MCPToolProviderController.from_db(provider) return controller @@ -928,7 +928,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> Union[str, dict]: + ) -> Union[str, dict[str, Any]]: """ get the tool icon diff --git a/api/models/account.py b/api/models/account.py index 4fec41c4e..019159d2d 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,10 +1,10 @@ import enum import json from datetime import datetime -from typing import Optional +from typing import Any, Optional import sqlalchemy as sa -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor @@ -225,11 +225,11 @@ class Tenant(Base): ) @property - def custom_config_dict(self): + def custom_config_dict(self) -> dict[str, Any]: return json.loads(self.custom_config) if self.custom_config else {} @custom_config_dict.setter - def custom_config_dict(self, value: dict): + def custom_config_dict(self, value: dict[str, Any]) -> None: self.custom_config = json.dumps(value) diff --git a/api/models/dataset.py b/api/models/dataset.py index 1d2cb410f..38b5c74de 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -286,7 +286,7 @@ class DatasetProcessRule(Base): "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "dataset_id": self.dataset_id, @@ -295,7 +295,7 @@ class DatasetProcessRule(Base): } @property - def rules_dict(self): + def rules_dict(self) -> dict[str, Any] | None: try: return json.loads(self.rules) if self.rules else None except JSONDecodeError: @@ -392,10 +392,10 @@ class Document(Base): return status @property - def data_source_info_dict(self): + def data_source_info_dict(self) -> dict[str, Any] | None: if self.data_source_info: try: - data_source_info_dict = json.loads(self.data_source_info) + data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) except JSONDecodeError: data_source_info_dict = {} @@ -403,10 +403,10 @@ class Document(Base): return None @property - def data_source_detail_dict(self): + def data_source_detail_dict(self) -> dict[str, Any]: if self.data_source_info: if self.data_source_type == "upload_file": - data_source_info_dict = json.loads(self.data_source_info) + data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) file_detail = ( db.session.query(UploadFile) .where(UploadFile.id == data_source_info_dict["upload_file_id"]) @@ -425,7 +425,8 @@ class Document(Base): } } elif self.data_source_type in {"notion_import", "website_crawl"}: - return json.loads(self.data_source_info) + result: dict[str, Any] = json.loads(self.data_source_info) + return result return {} @property @@ -471,7 +472,7 @@ class Document(Base): return self.updated_at @property - def doc_metadata_details(self): + def doc_metadata_details(self) -> list[dict[str, Any]] | None: if self.doc_metadata: document_metadatas = ( db.session.query(DatasetMetadata) @@ -481,9 +482,9 @@ class Document(Base): ) .all() ) - metadata_list = [] + metadata_list: list[dict[str, Any]] = [] for metadata in document_metadatas: - metadata_dict = { + metadata_dict: dict[str, Any] = { "id": metadata.id, "name": metadata.name, "type": metadata.type, @@ -497,13 +498,13 @@ class Document(Base): return None @property - def process_rule_dict(self): - if self.dataset_process_rule_id: + def process_rule_dict(self) -> dict[str, Any] | None: + if self.dataset_process_rule_id and self.dataset_process_rule: return self.dataset_process_rule.to_dict() return None - def get_built_in_fields(self): - built_in_fields = [] + def get_built_in_fields(self) -> list[dict[str, Any]]: + built_in_fields: list[dict[str, Any]] = [] built_in_fields.append( { "id": "built-in", @@ -546,7 +547,7 @@ class Document(Base): ) return built_in_fields - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "tenant_id": self.tenant_id, @@ -592,13 +593,13 @@ class Document(Base): "data_source_info_dict": self.data_source_info_dict, "average_segment_length": self.average_segment_length, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - "dataset": self.dataset.to_dict() if self.dataset else None, + "dataset": None, # Dataset class doesn't have a to_dict method "segment_count": self.segment_count, "hit_count": self.hit_count, } @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: dict[str, Any]): return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -711,46 +712,48 @@ class DocumentSegment(Base): ) @property - def child_chunks(self): - process_rule = self.document.dataset_process_rule - if process_rule.mode == "hierarchical": - rules = Rule(**process_rule.rules_dict) - if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) - return child_chunks or [] - else: - return [] - else: + def child_chunks(self) -> list[Any]: + if not self.document: return [] + process_rule = self.document.dataset_process_rule + if process_rule and process_rule.mode == "hierarchical": + rules_dict = process_rule.rules_dict + if rules_dict: + rules = Rule(**rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + .where(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + return [] - def get_child_chunks(self): - process_rule = self.document.dataset_process_rule - if process_rule.mode == "hierarchical": - rules = Rule(**process_rule.rules_dict) - if rules.parent_mode: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) - return child_chunks or [] - else: - return [] - else: + def get_child_chunks(self) -> list[Any]: + if not self.document: return [] + process_rule = self.document.dataset_process_rule + if process_rule and process_rule.mode == "hierarchical": + rules_dict = process_rule.rules_dict + if rules_dict: + rules = Rule(**rules_dict) + if rules.parent_mode: + child_chunks = ( + db.session.query(ChildChunk) + .where(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + return [] @property - def sign_content(self): + def sign_content(self) -> str: return self.get_sign_content() - def get_sign_content(self): - signed_urls = [] + def get_sign_content(self) -> str: + signed_urls: list[tuple[int, int, str]] = [] text = self.content # For data before v0.10.0 @@ -890,17 +893,22 @@ class DatasetKeywordTable(Base): ) @property - def keyword_table_dict(self): + def keyword_table_dict(self) -> dict[str, set[Any]] | None: class SetDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - super().__init__(object_hook=self.object_hook, *args, **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + def object_hook(dct: Any) -> Any: + if isinstance(dct, dict): + result: dict[str, Any] = {} + items = cast(dict[str, Any], dct).items() + for keyword, node_idxs in items: + if isinstance(node_idxs, list): + result[keyword] = set(cast(list[Any], node_idxs)) + else: + result[keyword] = node_idxs + return result + return dct - def object_hook(self, dct): - if isinstance(dct, dict): - for keyword, node_idxs in dct.items(): - if isinstance(node_idxs, list): - dct[keyword] = set(node_idxs) - return dct + super().__init__(object_hook=object_hook, *args, **kwargs) # get dataset dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() @@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base): updated_by = mapped_column(StringUUID, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "tenant_id": self.tenant_id, @@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base): } @property - def settings_dict(self): + def settings_dict(self) -> dict[str, Any] | None: try: return json.loads(self.settings) if self.settings else None except JSONDecodeError: return None @property - def dataset_bindings(self): + def dataset_bindings(self) -> list[dict[str, Any]]: external_knowledge_bindings = ( db.session.query(ExternalKnowledgeBindings) .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) @@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base): ) dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() - dataset_bindings = [] + dataset_bindings: list[dict[str, Any]] = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) diff --git a/api/models/model.py b/api/models/model.py index fbebdc817..f8ead1f87 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column @@ -24,7 +24,7 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers -from libs.helper import generate_string +from libs.helper import generate_string # type: ignore[import-not-found] from .account import Account, Tenant from .base import Base @@ -98,7 +98,7 @@ class App(Base): use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property - def desc_or_prompt(self): + def desc_or_prompt(self) -> str: if self.description: return self.description else: @@ -109,12 +109,12 @@ class App(Base): return "" @property - def site(self): + def site(self) -> Optional["Site"]: site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property - def app_model_config(self): + def app_model_config(self) -> Optional["AppModelConfig"]: if self.app_model_config_id: return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() @@ -130,11 +130,11 @@ class App(Base): return None @property - def api_base_url(self): + def api_base_url(self) -> str: return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property - def tenant(self): + def tenant(self) -> Optional[Tenant]: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -162,7 +162,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self): + def deleted_tools(self) -> list[dict[str, str]]: from core.tools.tool_manager import ToolManager from services.plugin.plugin_service import PluginService @@ -242,7 +242,7 @@ class App(Base): provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools = [] + deleted_tools: list[dict[str, str]] = [] for tool in tools: keys = list(tool.keys()) @@ -275,7 +275,7 @@ class App(Base): return deleted_tools @property - def tags(self): + def tags(self) -> list["Tag"]: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) @@ -291,7 +291,7 @@ class App(Base): return tags or [] @property - def author_name(self): + def author_name(self) -> Optional[str]: if self.created_by: account = db.session.query(Account).where(Account.id == self.created_by).first() if account: @@ -334,20 +334,20 @@ class AppModelConfig(Base): file_upload = mapped_column(sa.Text) @property - def app(self): + def app(self) -> Optional[App]: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def model_dict(self): + def model_dict(self) -> dict[str, Any]: return json.loads(self.model) if self.model else {} @property - def suggested_questions_list(self): + def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self): + def suggested_questions_after_answer_dict(self) -> dict[str, Any]: return ( json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer @@ -355,19 +355,19 @@ class AppModelConfig(Base): ) @property - def speech_to_text_dict(self): + def speech_to_text_dict(self) -> dict[str, Any]: return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property - def text_to_speech_dict(self): + def text_to_speech_dict(self) -> dict[str, Any]: return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property - def retriever_resource_dict(self): + def retriever_resource_dict(self) -> dict[str, Any]: return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property - def annotation_reply_dict(self): + def annotation_reply_dict(self) -> dict[str, Any]: annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) @@ -390,11 +390,11 @@ class AppModelConfig(Base): return {"enabled": False} @property - def more_like_this_dict(self): + def more_like_this_dict(self) -> dict[str, Any]: return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} @property - def sensitive_word_avoidance_dict(self): + def sensitive_word_avoidance_dict(self) -> dict[str, Any]: return ( json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance @@ -402,15 +402,15 @@ class AppModelConfig(Base): ) @property - def external_data_tools_list(self) -> list[dict]: + def external_data_tools_list(self) -> list[dict[str, Any]]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self): + def user_input_form_list(self) -> list[dict[str, Any]]: return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self): + def agent_mode_dict(self) -> dict[str, Any]: return ( json.loads(self.agent_mode) if self.agent_mode @@ -418,17 +418,17 @@ class AppModelConfig(Base): ) @property - def chat_prompt_config_dict(self): + def chat_prompt_config_dict(self) -> dict[str, Any]: return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} @property - def completion_prompt_config_dict(self): + def completion_prompt_config_dict(self) -> dict[str, Any]: return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} @property - def dataset_configs_dict(self): + def dataset_configs_dict(self) -> dict[str, Any]: if self.dataset_configs: - dataset_configs: dict = json.loads(self.dataset_configs) + dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: @@ -438,7 +438,7 @@ class AppModelConfig(Base): } @property - def file_upload_dict(self): + def file_upload_dict(self) -> dict[str, Any]: return ( json.loads(self.file_upload) if self.file_upload @@ -452,7 +452,7 @@ class AppModelConfig(Base): } ) - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -546,7 +546,7 @@ class RecommendedApp(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self): + def app(self) -> Optional[App]: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -570,12 +570,12 @@ class InstalledApp(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self): + def app(self) -> Optional[App]: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def tenant(self): + def tenant(self) -> Optional[Tenant]: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -622,7 +622,7 @@ class Conversation(Base): mode: Mapped[str] = mapped_column(String(255)) name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(sa.Text) - _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) introduction = mapped_column(sa.Text) system_instruction = mapped_column(sa.Text) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) @@ -652,7 +652,7 @@ class Conversation(Base): is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property - def inputs(self): + def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() # Convert file mapping to File object @@ -660,22 +660,39 @@ class Conversation(Base): # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. from factories import file_factory - if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - if value["transfer_method"] == FileTransferMethod.TOOL_FILE: - value["tool_file_id"] = value["related_id"] - elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value["upload_file_id"] = value["related_id"] - inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) - elif isinstance(value, list) and all( - isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + if ( + isinstance(value, dict) + and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): - inputs[key] = [] - for item in value: - if item["transfer_method"] == FileTransferMethod.TOOL_FILE: - item["tool_file_id"] = item["related_id"] - elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - item["upload_file_id"] = item["related_id"] - inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + value_dict = cast(dict[str, Any], value) + if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: + value_dict["tool_file_id"] = value_dict["related_id"] + elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: + value_dict["upload_file_id"] = value_dict["related_id"] + tenant_id = cast(str, value_dict.get("tenant_id", "")) + inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + elif isinstance(value, list): + value_list = cast(list[Any], value) + if all( + isinstance(item, dict) + and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY + for item in value_list + ): + file_list: list[File] = [] + for item in value_list: + if not isinstance(item, dict): + continue + item_dict = cast(dict[str, Any], item) + if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: + item_dict["tool_file_id"] = item_dict["related_id"] + elif item_dict["transfer_method"] in [ + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + ]: + item_dict["upload_file_id"] = item_dict["related_id"] + tenant_id = cast(str, item_dict.get("tenant_id", "")) + file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + inputs[key] = file_list return inputs @@ -685,8 +702,10 @@ class Conversation(Base): for k, v in inputs.items(): if isinstance(v, File): inputs[k] = v.model_dump() - elif isinstance(v, list) and all(isinstance(item, File) for item in v): - inputs[k] = [item.model_dump() for item in v] + elif isinstance(v, list): + v_list = cast(list[Any], v) + if all(isinstance(item, File) for item in v_list): + inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] self._inputs = inputs @property @@ -826,7 +845,7 @@ class Conversation(Base): ) @property - def app(self): + def app(self) -> Optional[App]: return db.session.query(App).where(App.id == self.app_id).first() @property @@ -839,7 +858,7 @@ class Conversation(Base): return None @property - def from_account_name(self): + def from_account_name(self) -> Optional[str]: if self.from_account_id: account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: @@ -848,10 +867,10 @@ class Conversation(Base): return None @property - def in_debug_mode(self): + def in_debug_mode(self) -> bool: return self.override_model_configs is not None - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "app_id": self.app_id, @@ -897,7 +916,7 @@ class Message(Base): model_id = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(sa.Text) conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) - _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) query: Mapped[str] = mapped_column(sa.Text, nullable=False) message = mapped_column(sa.JSON, nullable=False) message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) @@ -924,28 +943,45 @@ class Message(Base): workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property - def inputs(self): + def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() for key, value in inputs.items(): # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. from factories import file_factory - if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - if value["transfer_method"] == FileTransferMethod.TOOL_FILE: - value["tool_file_id"] = value["related_id"] - elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value["upload_file_id"] = value["related_id"] - inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) - elif isinstance(value, list) and all( - isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + if ( + isinstance(value, dict) + and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): - inputs[key] = [] - for item in value: - if item["transfer_method"] == FileTransferMethod.TOOL_FILE: - item["tool_file_id"] = item["related_id"] - elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - item["upload_file_id"] = item["related_id"] - inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + value_dict = cast(dict[str, Any], value) + if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: + value_dict["tool_file_id"] = value_dict["related_id"] + elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: + value_dict["upload_file_id"] = value_dict["related_id"] + tenant_id = cast(str, value_dict.get("tenant_id", "")) + inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + elif isinstance(value, list): + value_list = cast(list[Any], value) + if all( + isinstance(item, dict) + and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY + for item in value_list + ): + file_list: list[File] = [] + for item in value_list: + if not isinstance(item, dict): + continue + item_dict = cast(dict[str, Any], item) + if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: + item_dict["tool_file_id"] = item_dict["related_id"] + elif item_dict["transfer_method"] in [ + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + ]: + item_dict["upload_file_id"] = item_dict["related_id"] + tenant_id = cast(str, item_dict.get("tenant_id", "")) + file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + inputs[key] = file_list return inputs @inputs.setter @@ -954,8 +990,10 @@ class Message(Base): for k, v in inputs.items(): if isinstance(v, File): inputs[k] = v.model_dump() - elif isinstance(v, list) and all(isinstance(item, File) for item in v): - inputs[k] = [item.model_dump() for item in v] + elif isinstance(v, list): + v_list = cast(list[Any], v) + if all(isinstance(item, File) for item in v_list): + inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] self._inputs = inputs @property @@ -1083,15 +1121,15 @@ class Message(Base): return None @property - def in_debug_mode(self): + def in_debug_mode(self) -> bool: return self.override_model_configs is not None @property - def message_metadata_dict(self): + def message_metadata_dict(self) -> dict[str, Any]: return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self): + def agent_thoughts(self) -> list["MessageAgentThought"]: return ( db.session.query(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) @@ -1100,11 +1138,11 @@ class Message(Base): ) @property - def retriever_resources(self): + def retriever_resources(self) -> Any | list[Any]: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property - def message_files(self): + def message_files(self) -> list[dict[str, Any]]: from factories import file_factory message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() @@ -1112,7 +1150,7 @@ class Message(Base): if not current_app: raise ValueError(f"App {self.app_id} not found") - files = [] + files: list[File] = [] for message_file in message_files: if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: if message_file.upload_file_id is None: @@ -1159,7 +1197,7 @@ class Message(Base): ) files.append(file) - result = [ + result: list[dict[str, Any]] = [ {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} for (file, message_file) in zip(files, message_files) ] @@ -1176,7 +1214,7 @@ class Message(Base): return None - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "app_id": self.app_id, @@ -1200,7 +1238,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: dict[str, Any]) -> "Message": return cls( id=data["id"], app_id=data["app_id"], @@ -1250,7 +1288,7 @@ class MessageFeedback(Base): account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": str(self.id), "app_id": str(self.app_id), @@ -1435,7 +1473,18 @@ class EndUser(Base, UserMixin): type: Mapped[str] = mapped_column(String(255), nullable=False) external_user_id = mapped_column(String(255), nullable=True) name = mapped_column(String(255)) - is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + _is_anonymous: Mapped[bool] = mapped_column( + "is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true") + ) + + @property + def is_anonymous(self) -> Literal[False]: + return False + + @is_anonymous.setter + def is_anonymous(self, value: bool) -> None: + self._is_anonymous = value + session_id: Mapped[str] = mapped_column() created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1461,7 +1510,7 @@ class AppMCPServer(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod - def generate_server_code(n): + def generate_server_code(n: int) -> str: while True: result = generate_string(n) while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: @@ -1518,7 +1567,7 @@ class Site(Base): self._custom_disclaimer = value @staticmethod - def generate_code(n): + def generate_code(n: int) -> str: while True: result = generate_string(n) while db.session.query(Site).where(Site.code == result).count() > 0: @@ -1549,7 +1598,7 @@ class ApiToken(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod - def generate_api_key(prefix, n): + def generate_api_key(prefix: str, n: int) -> str: while True: result = prefix + generate_string(n) if db.session.scalar(select(exists().where(ApiToken.token == result))): @@ -1689,7 +1738,7 @@ class MessageAgentThought(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property - def files(self): + def files(self) -> list[Any]: if self.message_files: return cast(list[Any], json.loads(self.message_files)) else: @@ -1700,32 +1749,32 @@ class MessageAgentThought(Base): return self.tool.split(";") if self.tool else [] @property - def tool_labels(self): + def tool_labels(self) -> dict[str, Any]: try: if self.tool_labels_str: - return cast(dict, json.loads(self.tool_labels_str)) + return cast(dict[str, Any], json.loads(self.tool_labels_str)) else: return {} except Exception: return {} @property - def tool_meta(self): + def tool_meta(self) -> dict[str, Any]: try: if self.tool_meta_str: - return cast(dict, json.loads(self.tool_meta_str)) + return cast(dict[str, Any], json.loads(self.tool_meta_str)) else: return {} except Exception: return {} @property - def tool_inputs_dict(self): + def tool_inputs_dict(self) -> dict[str, Any]: tools = self.tools try: if self.tool_input: data = json.loads(self.tool_input) - result = {} + result: dict[str, Any] = {} for tool in tools: if tool in data: result[tool] = data[tool] @@ -1741,12 +1790,12 @@ class MessageAgentThought(Base): return {} @property - def tool_outputs_dict(self): + def tool_outputs_dict(self) -> dict[str, Any]: tools = self.tools try: if self.observation: data = json.loads(self.observation) - result = {} + result: dict[str, Any] = {} for tool in tools: if tool in data: result[tool] = data[tool] @@ -1844,14 +1893,14 @@ class TraceAppConfig(Base): is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @property - def tracing_config_dict(self): + def tracing_config_dict(self) -> dict[str, Any]: return self.tracing_config or {} @property - def tracing_config_str(self): + def tracing_config_str(self) -> str: return json.dumps(self.tracing_config_dict) - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "app_id": self.app_id, diff --git a/api/models/provider.py b/api/models/provider.py index 18bf0ac5a..9a344ea56 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -17,7 +17,7 @@ class ProviderType(Enum): SYSTEM = "system" @staticmethod - def value_of(value): + def value_of(value: str) -> "ProviderType": for member in ProviderType: if member.value == value: return member @@ -35,7 +35,7 @@ class ProviderQuotaType(Enum): """hosted trial quota""" @staticmethod - def value_of(value): + def value_of(value: str) -> "ProviderQuotaType": for member in ProviderQuotaType: if member.value == value: return member diff --git a/api/models/tools.py b/api/models/tools.py index 8755570ee..09c8cd400 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Optional, cast +from typing import Any, Optional, cast from urllib.parse import urlparse import sqlalchemy as sa @@ -54,8 +54,8 @@ class ToolOAuthTenantClient(Base): encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) @property - def oauth_params(self): - return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) + def oauth_params(self) -> dict[str, Any]: + return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}")) class BuiltinToolProvider(Base): @@ -96,8 +96,8 @@ class BuiltinToolProvider(Base): expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) @property - def credentials(self): - return cast(dict, json.loads(self.encrypted_credentials)) + def credentials(self) -> dict[str, Any]: + return cast(dict[str, Any], json.loads(self.encrypted_credentials)) class ApiToolProvider(Base): @@ -146,8 +146,8 @@ class ApiToolProvider(Base): return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] @property - def credentials(self): - return dict(json.loads(self.credentials_str)) + def credentials(self) -> dict[str, Any]: + return dict[str, Any](json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -289,9 +289,9 @@ class MCPToolProvider(Base): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def credentials(self): + def credentials(self) -> dict[str, Any]: try: - return cast(dict, json.loads(self.encrypted_credentials)) or {} + return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} except Exception: return {} @@ -327,12 +327,12 @@ class MCPToolProvider(Base): return mask_url(self.decrypted_server_url) @property - def decrypted_credentials(self): + def decrypted_credentials(self) -> dict[str, Any]: from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.utils.encryption import create_provider_encrypter - provider_controller = MCPToolProviderController._from_db(self) + provider_controller = MCPToolProviderController.from_db(self) encrypter, _ = create_provider_encrypter( tenant_id=self.tenant_id, @@ -340,7 +340,7 @@ class MCPToolProvider(Base): cache=NoOpProviderCredentialCache(), ) - return encrypter.decrypt(self.credentials) # type: ignore + return encrypter.decrypt(self.credentials) class ToolModelInvoke(Base): diff --git a/api/models/types.py b/api/models/types.py index e5581c3ab..cc69ae4f5 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,29 +1,34 @@ import enum -from typing import Generic, TypeVar +import uuid +from typing import Any, Generic, TypeVar from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql.type_api import TypeEngine -class StringUUID(TypeDecorator): +class StringUUID(TypeDecorator[uuid.UUID | str | None]): impl = CHAR cache_ok = True - def process_bind_param(self, value, dialect): + def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value elif dialect.name == "postgresql": return str(value) else: - return value.hex + if isinstance(value, uuid.UUID): + return value.hex + return value - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) - def process_result_value(self, value, dialect): + def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value return str(value) @@ -32,7 +37,7 @@ class StringUUID(TypeDecorator): _E = TypeVar("_E", bound=enum.StrEnum) -class EnumText(TypeDecorator, Generic[_E]): +class EnumText(TypeDecorator[_E | None], Generic[_E]): impl = VARCHAR cache_ok = True @@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]): # leave some rooms for future longer enum values. self._length = max(max_enum_value_len, 20) - def process_bind_param(self, value: _E | str | None, dialect): + def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: if value is None: return value if isinstance(value, self._enum_class): return value.value - elif isinstance(value, str): - self._enum_class(value) - return value - else: - raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") + # Since _E is bound to StrEnum which inherits from str, at this point value must be str + self._enum_class(value) + return value - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: return dialect.type_descriptor(VARCHAR(self._length)) - def process_result_value(self, value, dialect) -> _E | None: + def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: if value is None: return value - if not isinstance(value, str): - raise TypeError(f"expected str, got {type(value)}") + # Type annotation guarantees value is str at this point return self._enum_class(value) - def compare_values(self, x, y): + def compare_values(self, x: _E | None, y: _E | None) -> bool: if x is None or y is None: return x is y return x == y diff --git a/api/models/workflow.py b/api/models/workflow.py index 23f18929d..4686b38b0 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -224,7 +224,7 @@ class Workflow(Base): raise WorkflowDataError("nodes not found in workflow graph") try: - node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) assert isinstance(node_config, dict) @@ -289,7 +289,7 @@ class Workflow(Base): def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} - def user_input_form(self, to_old_structure: bool = False): + def user_input_form(self, to_old_structure: bool = False) -> list[Any]: # get start node from graph if not self.graph: return [] @@ -306,7 +306,7 @@ class Workflow(Base): variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: - old_structure_variables = [] + old_structure_variables: list[dict[str, Any]] = [] for variable in variables: old_structure_variables.append({variable["type"]: variable}) @@ -346,9 +346,7 @@ class Workflow(Base): @property def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # TODO: find some way to init `self._environment_variables` when instance created. - if self._environment_variables is None: - self._environment_variables = "{}" + # _environment_variables is guaranteed to be non-None due to server_default="{}" # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id @@ -362,17 +360,18 @@ class Workflow(Base): ] # decrypt secret variables value - def decrypt_func(var): + def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var else: - raise AssertionError("this statement should be unreachable.") + # Other variable types are not supported for environment variables + raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") - decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( - map(decrypt_func, results) - ) + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ + decrypt_func(var) for var in results + ] return decrypted_results @environment_variables.setter @@ -400,7 +399,7 @@ class Workflow(Base): value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - def encrypt_func(var): + def encrypt_func(var: Variable) -> Variable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) else: @@ -430,9 +429,7 @@ class Workflow(Base): @property def conversation_variables(self) -> Sequence[Variable]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._conversation_variables is None: - self._conversation_variables = "{}" + # _conversation_variables is guaranteed to be non-None due to server_default="{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] @@ -577,7 +574,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict) -> "WorkflowRun": + def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base): __tablename__ = "workflow_node_executions" @declared_attr - def __table_args__(cls): # noqa + @classmethod + def __table_args__(cls) -> Any: return ( PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), Index( @@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base): # MyPy may flag the following line because it doesn't recognize that # the `declared_attr` decorator passes the receiving class as the first # argument to this method, allowing us to reference class attributes. - cls.created_at.desc(), # type: ignore + cls.created_at.desc(), ), ) @@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base): return json.loads(self.execution_metadata) if self.execution_metadata else {} @property - def extras(self): + def extras(self) -> dict[str, Any]: from core.tools.tool_manager import ToolManager - extras = {} + extras: dict[str, Any] = {} if self.execution_metadata_dict: from core.workflow.nodes import NodeType if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: - tool_info = self.execution_metadata_dict["tool_info"] + tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], @@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base): # making this attribute harder to access from outside the class. __value: Segment | None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ The constructor of `WorkflowDraftVariable` is not intended for direct use outside this file. Its solo purpose is setup private state @@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base): self.__value = None def get_selector(self) -> list[str]: - selector = json.loads(self.selector) + selector: Any = json.loads(self.selector) if not isinstance(selector, list): logger.error( "invalid selector loaded from database, type=%s, value=%s", - type(selector), + type(selector).__name__, self.selector, ) raise ValueError("invalid selector.") - return selector + return cast(list[str], selector) def _set_selector(self, value: list[str]): self.selector = json.dumps(value) @@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base): # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. if isinstance(value, dict): if not maybe_file_object(value): - return value + return cast(Any, value) return File.model_validate(value) elif isinstance(value, list) and value: - first = value[0] + value_list = cast(list[Any], value) + first: Any = value_list[0] if not maybe_file_object(first): - return value - return [File.model_validate(i) for i in value] + return cast(Any, value) + file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + return cast(Any, file_list) else: - return value + return cast(Any, value) @classmethod def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 8694f44fa..059b8bba4 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -6,7 +6,6 @@ "tests/", "migrations/", ".venv/", - "models/", "core/", "controllers/", "tasks/", diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 72833b9d6..76267a2fe 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,5 @@ import threading -from typing import Optional +from typing import Any, Optional import pytz from flask_login import current_user @@ -68,7 +68,7 @@ class AgentService: if not app_model_config: raise ValueError("App model config not found") - result = { + result: dict[str, Any] = { "meta": { "status": "success", "executor": executor, diff --git a/api/services/app_service.py b/api/services/app_service.py index 4502fa929..09aab5f0c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -171,6 +171,8 @@ class AppService: # get original app model config if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: model_config = app.app_model_config + if not model_config: + return app agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get("tools") or []: @@ -205,7 +207,8 @@ class AppService: pass # override agent mode - model_config.agent_mode = json.dumps(agent_mode) + if model_config: + model_config.agent_mode = json.dumps(agent_mode) class ModifiedApp(App): """ diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 0084eebb3..9b1999d81 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -12,7 +12,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.enums import MessageStatus -from models.model import App, AppMode, AppModelConfig, Message +from models.model import App, AppMode, Message from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, @@ -40,7 +40,9 @@ class AudioService: if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config = app_model.app_model_config + if not app_model_config: + raise ValueError("Speech to text is not enabled") if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e0885f325..c0c97fbd7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -973,7 +973,7 @@ class DocumentService: file_ids = [ document.data_source_info_dict["upload_file_id"] for document in documents - if document.data_source_type == "upload_file" + if document.data_source_type == "upload_file" and document.data_source_info_dict ] batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) @@ -1067,8 +1067,9 @@ class DocumentService: # sync document indexing document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info["mode"] = "scrape" - document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) + if data_source_info: + data_source_info["mode"] = "scrape" + document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 783d6c242..3262a0066 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -114,8 +114,9 @@ class ExternalDatasetService: ) if external_knowledge_api is None: raise ValueError("api template not found") - if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: - args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") + settings = args.get("settings") + if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict: + settings["api_key"] = external_knowledge_api.settings_dict.get("api_key") external_knowledge_api.name = args.get("name") external_knowledge_api.description = args.get("description", "") diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 665ef27d6..b557d2155 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -226,7 +226,7 @@ class MCPToolManageService: def update_mcp_provider_credentials( cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False ): - provider_controller = MCPToolProviderController._from_db(mcp_provider) + provider_controller = MCPToolProviderController.from_db(mcp_provider) tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/unit_tests/models/test_types_enum_text.py index e4061b72c..c59afcf0d 100644 --- a/api/tests/unit_tests/models/test_types_enum_text.py +++ b/api/tests/unit_tests/models/test_types_enum_text.py @@ -154,7 +154,7 @@ class TestEnumText: TestCase( name="session insert with invalid type", action=lambda s: _session_insert_with_value(s, 1), - exc_type=TypeError, + exc_type=ValueError, ), TestCase( name="insert with invalid value", @@ -164,7 +164,7 @@ class TestEnumText: TestCase( name="insert with invalid type", action=lambda s: _insert_with_user(s, 1), - exc_type=TypeError, + exc_type=ValueError, ), ] for idx, c in enumerate(cases, 1):