diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 3de2f5ca9..8d256da9c 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): environment_variables=self._workflow.environment_variables, # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - conversation_variables=cast(list[VariableUnion], conversation_variables), + conversation_variables=conversation_variables, ) # init graph diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index cac7e8e6e..383a2dd57 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -3,7 +3,7 @@ import base64 from libs import rsa -def obfuscated_token(token: str): +def obfuscated_token(token: str) -> str: if not token: return token if len(token) <= 8: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 51af3d187..e56756554 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -158,8 +158,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - - self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return cast( Union[LLMResult, Generator], self._round_robin_invoke( @@ -188,8 +186,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - - self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return cast( int, self._round_robin_invoke( @@ -214,8 +210,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - - self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return cast( TextEmbeddingResult, self._round_robin_invoke( @@ -237,8 +231,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - - self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return cast( list[int], self._round_robin_invoke( @@ -269,8 +261,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - - self.model_type_instance = cast(RerankModel, self.model_type_instance) return cast( RerankResult, self._round_robin_invoke( @@ -295,8 +285,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") - - self.model_type_instance = cast(ModerationModel, self.model_type_instance) return cast( bool, self._round_robin_invoke( @@ -318,8 +306,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") - - self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) return cast( str, self._round_robin_invoke( @@ -343,8 +329,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) return cast( Iterable[bytes], self._round_robin_invoke( @@ -404,8 +388,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( model=self.model, credentials=self.credentials, language=language ) diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 2f4e65146..cdc6ccc82 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -87,7 +87,6 @@ class PromptMessageUtil: if isinstance(prompt_message.content, list): for content in prompt_message.content: if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) text += content.data else: content = cast(ImagePromptMessageContent, content) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 28a4ce077..cad0de647 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -2,7 +2,7 @@ import contextlib import json from collections import defaultdict from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, Optional from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -154,8 +154,8 @@ class ProviderManager: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, data=provider_entity, name_func=lambda x: x.provider, ): diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index fcf3a6d12..41ad5e57e 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union import qdrant_client from flask import current_app @@ -426,7 +426,6 @@ class QdrantVector(BaseVector): def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): - self._client = cast(QdrantLocal, self._client) self._client._load() @classmethod diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index c97765b1d..3845392c8 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from typing import Optional, cast +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor): markdown_tups.append((current_header, current_text)) markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) + (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 17f4d1af2..3d4b898c9 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor): f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}" ) - return cast(str, data_source_binding.access_token) + return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 7dfe2e357..3c43f3410 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -2,7 +2,7 @@ import contextlib from collections.abc import Iterator -from typing import Optional, cast +from typing import Optional from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor): plaintext_file_exists = False if self._file_cache_key: with contextlib.suppress(FileNotFoundError): - text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] documents = list(self.load()) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 3454ec348..b338a779a 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -331,16 +331,13 @@ class ToolManager: if controller_tools is None or len(controller_tools) == 0: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return cast( - WorkflowTool, - controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") @@ -648,8 +645,8 @@ class ToolManager: for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider, name_func=lambda x: x.identity.name, ): diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 8357dac0d..bf075bd73 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,7 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension -from typing import Optional, cast +from typing import Optional from uuid import UUID import numpy as np @@ -159,8 +159,7 @@ class ToolFileMessageTransformer: elif message.type == ToolInvokeMessage.MessageType.JSON: if isinstance(message.message, ToolInvokeMessage.JsonMessage): - json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) - json_msg.json_object = safe_json_value(json_msg.json_object) + message.message.json_object = safe_json_value(message.message.json_object) yield message else: yield message diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 3f59b3f47..251d91480 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -129,17 +129,14 @@ class ModelInvocationUtils: db.session.commit() try: - response: LLMResult = cast( - LLMResult, - model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], - ), + response: LLMResult = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: raise InvokeModelError(f"Invoke rate limit error: {e}") diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 1387df597..ea219af68 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional, cast +from typing import Any, Optional from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool @@ -204,14 +204,14 @@ class WorkflowTool(Tool): item = self._update_file_mapping(item) file = build_from_mapping( mapping=item, - tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), + tenant_id=str(self.runtime.tenant_id), ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: value = self._update_file_mapping(value) file = build_from_mapping( mapping=value, - tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), + tenant_id=str(self.runtime.tenant_id), ) files.append(file) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 16c8116ac..a994730cd 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Annotated, TypeAlias, cast +from typing import Annotated, TypeAlias from uuid import uuid4 from pydantic import Discriminator, Field, Tag @@ -86,7 +86,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return cast(str, encrypter.obfuscated_token(self.value)) + return encrypter.obfuscated_token(self.value) class NoneVariable(NoneSegment, Variable): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 03b920ccb..188d0c475 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -374,7 +374,7 @@ class GraphEngine: if len(sub_edge_mappings) == 0: continue - edge = cast(GraphEdge, sub_edge_mappings[0]) + edge = sub_edge_mappings[0] if edge.run_condition is None: logger.warning("Edge %s run condition is None", edge.target_node_id) continue diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 144f036aa..9e5d5e62b 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -153,7 +153,7 @@ class AgentNode(BaseNode): messages=message_stream, tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, + "agent_strategy": self._node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, user_id=self.user_id, @@ -394,8 +394,7 @@ class AgentNode(BaseNode): current_plugin = next( plugin for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self._node_data).agent_strategy_provider_name + if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index b820999c3..bb09b1a5d 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str: encoding = "utf-8" yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: # If decoding fails, try with utf-8 as last resort try: yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) except (UnicodeDecodeError, yaml.YAMLError): raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 3dcde5ad8..43edf7eac 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode): """ Run the node. """ - node_data = cast(ParameterExtractorNodeData, self._node_data) + node_data = self._node_data variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 3e4984ecd..ba4e55bb8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode): return "1" def _run(self): - node_data = cast(QuestionClassifierNodeData, self._node_data) + node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool # extract variables diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 4c8e13de7..1a85c08b5 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from sqlalchemy import select from sqlalchemy.orm import Session @@ -57,7 +57,7 @@ class ToolNode(BaseNode): Run the tool node """ - node_data = cast(ToolNodeData, self._node_data) + node_data = self._node_data # fetch tool icon tool_info = { diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 801e36e27..e9b73df0f 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -261,7 +261,6 @@ class WorkflowEntry: environment_variables=[], ) - node_cls = cast(type[BaseNode], node_cls) # init workflow run state node: BaseNode = node_cls( id=str(uuid.uuid4()), diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 0ea7d3ae1..62e3bfa3b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -3,7 +3,7 @@ import os import urllib.parse import uuid from collections.abc import Callable, Mapping, Sequence -from typing import Any, cast +from typing import Any import httpx from sqlalchemy import select @@ -258,7 +258,6 @@ def _get_remote_file_info(url: str): mime_type = "" resp = ssrf_proxy.head(url, follow_redirects=True) - resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) diff --git a/api/models/tools.py b/api/models/tools.py index e0c9fa6ff..d88d81737 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -308,7 +308,7 @@ class MCPToolProvider(Base): @property def decrypted_server_url(self) -> str: - return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url)) + return encrypter.decrypt_token(self.tenant_id, self.server_url) @property def masked_server_url(self) -> str: diff --git a/api/services/account_service.py b/api/services/account_service.py index 089e66716..50ce171de 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -146,7 +146,7 @@ class AccountService: account.last_active_at = naive_utc_now() db.session.commit() - return cast(Account, account) + return account @staticmethod def get_account_jwt_token(account: Account) -> str: @@ -191,7 +191,7 @@ class AccountService: db.session.commit() - return cast(Account, account) + return account @staticmethod def update_account_password(account, password, new_password): @@ -1127,7 +1127,7 @@ class TenantService: def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) - return cast(dict, tenant.custom_config_dict) + return tenant.custom_config_dict @staticmethod def is_owner(account: Account, tenant: Tenant) -> bool: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 6603063c2..9ee92bc2d 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,5 +1,5 @@ import uuid -from typing import cast +from typing import Optional import pandas as pd from flask_login import current_user @@ -40,7 +40,7 @@ class AppAnnotationService: if not message: raise NotFound("Message Not Exists.") - annotation = message.annotation + annotation: Optional[MessageAnnotation] = message.annotation # save the message annotation if annotation: annotation.content = args["answer"] @@ -70,7 +70,7 @@ class AppAnnotationService: app_id, annotation_setting.collection_binding_id, ) - return cast(MessageAnnotation, annotation) + return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f659c5e1..eb85d6118 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,7 +1,6 @@ import time import uuid from os import getenv -from typing import cast import pytest @@ -13,7 +12,6 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -238,8 +236,6 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } - node._node_data = cast(CodeNodeData, node._node_data) - # validate node._transform_result(result, node._node_data.outputs) @@ -334,8 +330,6 @@ def test_execute_code_output_object_list(): ] } - node._node_data = cast(CodeNodeData, node._node_data) - # validate node._transform_result(result, node._node_data.outputs)