From a78339a040d5074333c0187bca7c724fb8bb95e4 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sat, 6 Sep 2025 04:32:23 +0900 Subject: [PATCH] remove bare list, dict, Sequence, None, Any (#25058) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- --- api/configs/remote_settings_sources/base.py | 2 +- .../remote_settings_sources/nacos/__init__.py | 2 +- api/controllers/console/app/generator.py | 2 +- .../console/app/workflow_draft_variable.py | 6 +-- api/controllers/mcp/mcp.py | 2 +- api/core/agent/base_agent_runner.py | 2 +- api/core/agent/cot_agent_runner.py | 2 +- api/core/agent/entities.py | 2 +- .../easy_ui_based_app/dataset/manager.py | 2 +- .../easy_ui_based_app/model_config/manager.py | 2 +- .../prompt_template/manager.py | 2 +- .../apps/advanced_chat/app_config_manager.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/advanced_chat/app_runner.py | 6 +-- .../advanced_chat/generate_task_pipeline.py | 6 +-- .../app/apps/agent_chat/app_config_manager.py | 2 +- api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 2 +- .../agent_chat/generate_response_converter.py | 4 +- .../base_app_generate_response_converter.py | 2 +- api/core/app/apps/base_app_generator.py | 2 +- api/core/app/apps/base_app_queue_manager.py | 12 ++--- api/core/app/apps/base_app_runner.py | 10 ++-- api/core/app/apps/chat/app_config_manager.py | 2 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/app/apps/chat/app_runner.py | 2 +- .../apps/chat/generate_response_converter.py | 4 +- .../common/workflow_response_converter.py | 2 +- .../app/apps/completion/app_config_manager.py | 2 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 2 +- .../completion/generate_response_converter.py | 4 +- .../apps/message_based_app_queue_manager.py | 4 +- .../app/apps/workflow/app_config_manager.py | 2 +- api/core/app/apps/workflow/app_generator.py | 2 +- .../app/apps/workflow/app_queue_manager.py | 4 +- api/core/app/apps/workflow/app_runner.py | 4 +- .../workflow/generate_response_converter.py | 4 +- .../apps/workflow/generate_task_pipeline.py | 6 +-- api/core/app/apps/workflow_app_runner.py | 6 +-- .../based_generate_task_pipeline.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 6 +-- .../task_pipeline/message_cycle_manager.py | 4 +- .../agent_tool_callback_handler.py | 14 +++--- .../index_tool_callback_handler.py | 6 +-- api/core/entities/model_entities.py | 4 +- api/core/entities/provider_configuration.py | 34 +++++++------- api/core/errors/error.py | 2 +- .../api_based_extension_requestor.py | 4 +- api/core/extension/extensible.py | 2 +- api/core/external_data_tool/api/api.py | 2 +- api/core/external_data_tool/base.py | 4 +- api/core/external_data_tool/factory.py | 4 +- api/core/file/tool_file_parser.py | 2 +- .../code_executor/code_node_provider.py | 2 +- .../jinja2/jinja2_transformer.py | 2 +- .../python3/python3_code_provider.py | 2 +- api/core/helper/model_provider_cache.py | 4 +- api/core/helper/provider_cache.py | 8 ++-- api/core/helper/tool_parameter_cache.py | 4 +- api/core/helper/trace_id_helper.py | 2 +- api/core/hosting_configuration.py | 4 +- api/core/indexing_runner.py | 6 +-- api/core/llm_generator/llm_generator.py | 14 +++--- .../output_parser/rule_config_generator.py | 4 +- .../output_parser/structured_output.py | 10 ++-- .../suggested_questions_after_answer.py | 3 +- api/core/mcp/auth/auth_provider.py | 6 +-- api/core/mcp/client/sse_client.py | 16 +++---- api/core/mcp/client/streamable_client.py | 24 +++++----- api/core/mcp/session/base_session.py | 30 ++++++------ api/core/mcp/session/client_session.py | 22 ++++----- api/core/memory/token_buffer_memory.py | 2 +- api/core/model_manager.py | 14 +++--- .../model_runtime/callbacks/base_callback.py | 8 ++-- .../callbacks/logging_callback.py | 6 +-- api/core/model_runtime/errors/invoke.py | 2 +- .../model_providers/__base/ai_model.py | 2 +- .../__base/large_language_model.py | 8 ++-- .../__base/tokenizers/gpt2_tokenizer.py | 2 +- .../model_providers/__base/tts_model.py | 2 +- .../model_providers/model_provider_factory.py | 8 ++-- .../schema_validators/common_validator.py | 2 +- .../model_credential_schema_validator.py | 2 +- .../provider_credential_schema_validator.py | 2 +- api/core/model_runtime/utils/encoders.py | 4 +- api/core/moderation/api/api.py | 4 +- api/core/moderation/base.py | 6 +-- api/core/moderation/factory.py | 4 +- api/core/moderation/keywords/keywords.py | 2 +- .../openai_moderation/openai_moderation.py | 2 +- api/core/moderation/output_moderation.py | 2 +- .../plugin/backwards_invocation/encrypt.py | 2 +- api/core/plugin/backwards_invocation/node.py | 4 +- api/core/plugin/entities/plugin.py | 8 ++-- api/core/plugin/impl/agent.py | 4 +- api/core/plugin/impl/exc.py | 2 +- api/core/plugin/impl/model.py | 2 +- api/core/plugin/impl/tool.py | 4 +- api/core/plugin/utils/chunk_merger.py | 2 +- api/core/prompt/advanced_prompt_transform.py | 2 +- api/core/prompt/simple_prompt_transform.py | 4 +- api/core/prompt/utils/prompt_message_util.py | 2 +- .../prompt/utils/prompt_template_parser.py | 2 +- api/core/provider_manager.py | 2 +- .../rag/datasource/keyword/jieba/jieba.py | 10 ++-- .../rag/datasource/keyword/keyword_base.py | 4 +- .../rag/datasource/keyword/keyword_factory.py | 4 +- .../vdb/analyticdb/analyticdb_vector.py | 6 +-- .../analyticdb/analyticdb_vector_openapi.py | 14 +++--- .../vdb/analyticdb/analyticdb_vector_sql.py | 12 ++--- .../rag/datasource/vdb/baidu/baidu_vector.py | 12 ++--- .../datasource/vdb/chroma/chroma_vector.py | 2 +- .../vdb/clickzetta/clickzetta_vector.py | 26 +++++------ .../vdb/couchbase/couchbase_vector.py | 6 +-- .../vdb/elasticsearch/elasticsearch_vector.py | 8 ++-- .../vdb/huawei/huawei_cloud_vector.py | 8 ++-- .../datasource/vdb/lindorm/lindorm_vector.py | 12 ++--- .../vdb/matrixone/matrixone_vector.py | 8 ++-- .../datasource/vdb/milvus/milvus_vector.py | 8 ++-- .../datasource/vdb/myscale/myscale_vector.py | 6 +-- .../vdb/oceanbase/oceanbase_vector.py | 10 ++-- .../rag/datasource/vdb/opengauss/opengauss.py | 8 ++-- .../vdb/opensearch/opensearch_vector.py | 6 +-- .../rag/datasource/vdb/oracle/oraclevector.py | 8 ++-- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 6 +-- .../rag/datasource/vdb/pgvector/pgvector.py | 8 ++-- .../vdb/pyvastbase/vastbase_vector.py | 8 ++-- .../datasource/vdb/qdrant/qdrant_vector.py | 4 +- .../rag/datasource/vdb/relyt/relyt_vector.py | 8 ++-- .../vdb/tablestore/tablestore_vector.py | 18 ++++---- .../datasource/vdb/tencent/tencent_vector.py | 10 ++-- .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 4 +- .../datasource/vdb/tidb_vector/tidb_vector.py | 8 ++-- .../datasource/vdb/upstash/upstash_vector.py | 10 ++-- api/core/rag/datasource/vdb/vector_base.py | 6 +-- api/core/rag/datasource/vdb/vector_factory.py | 8 ++-- .../vdb/vikingdb/vikingdb_vector.py | 6 +-- .../vdb/weaviate/weaviate_vector.py | 10 ++-- api/core/rag/docstore/dataset_docstore.py | 10 ++-- api/core/rag/embedding/cached_embedding.py | 2 +- .../rag/extractor/entity/extract_setting.py | 4 +- .../rag/extractor/firecrawl/firecrawl_app.py | 2 +- api/core/rag/extractor/helpers.py | 2 +- api/core/rag/extractor/watercrawl/provider.py | 8 ++-- api/core/rag/extractor/word_extractor.py | 2 +- api/core/rag/rerank/rerank_model.py | 2 +- api/core/rag/rerank/weight_rerank.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 4 +- api/core/rag/splitter/text_splitter.py | 6 +-- .../celery_workflow_execution_repository.py | 2 +- ...lery_workflow_node_execution_repository.py | 2 +- ...qlalchemy_workflow_execution_repository.py | 2 +- ...hemy_workflow_node_execution_repository.py | 8 ++-- api/core/tools/__base/tool.py | 2 +- api/core/tools/__base/tool_provider.py | 4 +- api/core/tools/builtin_tool/provider.py | 6 +-- .../builtin_tool/providers/audio/audio.py | 2 +- .../tools/builtin_tool/providers/code/code.py | 2 +- .../tools/builtin_tool/providers/time/time.py | 2 +- .../providers/webscraper/webscraper.py | 2 +- api/core/tools/custom_tool/provider.py | 2 +- api/core/tools/custom_tool/tool.py | 4 +- api/core/tools/entities/api_entities.py | 4 +- api/core/tools/entities/common_entities.py | 2 +- api/core/tools/entities/tool_entities.py | 4 +- api/core/tools/mcp_tool/provider.py | 4 +- api/core/tools/mcp_tool/tool.py | 2 +- api/core/tools/plugin_tool/provider.py | 4 +- api/core/tools/plugin_tool/tool.py | 2 +- api/core/tools/tool_manager.py | 6 +-- api/core/tools/utils/configuration.py | 2 +- .../tools/utils/dataset_retriever_tool.py | 2 +- api/core/tools/utils/encryption.py | 4 +- api/core/tools/utils/parser.py | 2 +- api/core/tools/utils/yaml_utils.py | 2 +- api/core/tools/workflow_as_tool/tool.py | 2 +- api/core/variables/segments.py | 2 +- api/core/variables/types.py | 2 +- api/core/variables/utils.py | 2 +- .../callbacks/base_workflow_callback.py | 2 +- .../callbacks/workflow_logging_callback.py | 32 ++++++------- .../workflow/conversation_variable_updater.py | 2 +- api/core/workflow/entities/variable_pool.py | 8 ++-- .../entities/workflow_node_execution.py | 2 +- .../workflow/graph_engine/entities/graph.py | 14 +++--- .../entities/runtime_route_state.py | 4 +- .../workflow/graph_engine/graph_engine.py | 8 ++-- api/core/workflow/nodes/agent/agent_node.py | 2 +- api/core/workflow/nodes/answer/answer_node.py | 2 +- .../answer/answer_stream_generate_router.py | 2 +- .../nodes/answer/answer_stream_processor.py | 4 +- .../nodes/answer/base_stream_processor.py | 6 +-- api/core/workflow/nodes/base/entities.py | 2 +- api/core/workflow/nodes/base/node.py | 6 +-- api/core/workflow/nodes/code/code_node.py | 4 +- .../workflow/nodes/document_extractor/node.py | 2 +- api/core/workflow/nodes/end/end_node.py | 2 +- .../nodes/end/end_stream_generate_router.py | 2 +- .../nodes/end/end_stream_processor.py | 4 +- api/core/workflow/nodes/http_request/node.py | 4 +- .../workflow/nodes/if_else/if_else_node.py | 2 +- .../nodes/iteration/iteration_node.py | 4 +- .../nodes/iteration/iteration_start_node.py | 2 +- .../knowledge_retrieval_node.py | 4 +- api/core/workflow/nodes/list_operator/node.py | 2 +- api/core/workflow/nodes/llm/exc.py | 2 +- api/core/workflow/nodes/llm/llm_utils.py | 2 +- api/core/workflow/nodes/llm/node.py | 6 +-- api/core/workflow/nodes/loop/loop_end_node.py | 2 +- api/core/workflow/nodes/loop/loop_node.py | 2 +- .../workflow/nodes/loop/loop_start_node.py | 2 +- .../nodes/parameter_extractor/entities.py | 2 +- .../workflow/nodes/parameter_extractor/exc.py | 2 +- .../parameter_extractor_node.py | 10 ++-- .../question_classifier_node.py | 6 +-- api/core/workflow/nodes/start/start_node.py | 2 +- .../template_transform_node.py | 4 +- api/core/workflow/nodes/tool/tool_node.py | 2 +- .../variable_aggregator_node.py | 2 +- .../nodes/variable_assigner/common/impl.py | 2 +- .../nodes/variable_assigner/v1/node.py | 4 +- .../nodes/variable_assigner/v2/exc.py | 2 +- .../nodes/variable_assigner/v2/node.py | 2 +- .../workflow_execution_repository.py | 2 +- .../workflow_node_execution_repository.py | 2 +- .../utils/variable_template_parser.py | 2 +- api/core/workflow/workflow_cycle_manager.py | 10 ++-- api/core/workflow/workflow_entry.py | 6 +-- api/core/workflow/workflow_type_encoder.py | 2 +- .../update_provider_when_message_created.py | 2 +- api/extensions/ext_database.py | 6 +-- api/extensions/ext_orjson.py | 2 +- .../clickzetta_volume_storage.py | 6 +-- .../clickzetta_volume/file_lifecycle.py | 2 +- .../clickzetta_volume/volume_permissions.py | 2 +- api/extensions/storage/opendal_storage.py | 2 +- api/factories/file_factory.py | 2 +- api/libs/email_i18n.py | 12 ++--- api/libs/external_api.py | 2 +- api/libs/json_in_md_parser.py | 4 +- api/libs/module_loading.py | 5 +- api/models/account.py | 2 +- api/models/model.py | 44 +++++++++--------- api/models/tools.py | 14 +++--- api/models/workflow.py | 10 ++-- .../sqlalchemy_api_workflow_run_repository.py | 2 +- api/services/account_service.py | 34 +++++++------- .../advanced_prompt_template_service.py | 10 ++-- api/services/agent_service.py | 2 +- api/services/annotation_service.py | 8 ++-- api/services/api_based_extension_service.py | 6 +-- api/services/app_dsl_service.py | 4 +- api/services/app_model_config_service.py | 2 +- api/services/app_service.py | 4 +- api/services/auth/api_key_auth_service.py | 2 +- .../clear_free_plan_tenant_expired_logs.py | 4 +- api/services/code_based_extension_service.py | 2 +- api/services/conversation_service.py | 2 +- api/services/dataset_service.py | 2 +- .../entities/model_provider_entities.py | 8 ++-- api/services/errors/llm.py | 2 +- api/services/external_knowledge_service.py | 2 +- api/services/hit_testing_service.py | 4 +- api/services/model_load_balancing_service.py | 14 +++--- api/services/model_provider_service.py | 30 ++++++------ api/services/plugin/data_migration.py | 8 ++-- api/services/plugin/plugin_migration.py | 10 ++-- .../buildin/buildin_retrieval.py | 6 +-- .../database/database_retrieval.py | 4 +- .../recommend_app/recommend_app_base.py | 2 +- .../recommend_app/remote/remote_retrieval.py | 4 +- api/services/recommended_app_service.py | 2 +- api/services/tag_service.py | 8 ++-- .../tools/workflow_tools_manage_service.py | 12 ++--- api/services/website_service.py | 2 +- api/services/workflow/workflow_converter.py | 12 ++--- api/services/workflow_app_service.py | 2 +- .../workflow_draft_variable_service.py | 6 +-- api/services/workflow_service.py | 2 +- api/tasks/delete_conversation_task.py | 2 +- api/tasks/mail_account_deletion_task.py | 4 +- api/tasks/mail_change_mail_task.py | 4 +- api/tasks/mail_email_code_login.py | 2 +- api/tasks/mail_invite_member_task.py | 2 +- api/tasks/mail_owner_transfer_task.py | 6 +-- api/tasks/mail_reset_password_task.py | 2 +- api/tasks/remove_app_and_related_data_task.py | 2 +- api/tasks/workflow_execution_tasks.py | 2 +- api/tasks/workflow_node_execution_tasks.py | 4 +- api/tests/integration_tests/conftest.py | 2 +- .../model_runtime/__mock/plugin_daemon.py | 2 +- .../vdb/__mock/tcvectordb.py | 4 +- .../vdb/test_vector_store.py | 4 +- .../workflow/nodes/__mock/code_executor.py | 2 +- .../workflow/nodes/test_code.py | 10 ++-- .../conftest.py | 4 +- .../test_workflow_response_converter.py | 2 +- .../core/mcp/client/test_session.py | 2 +- .../graph_engine/test_graph_engine.py | 2 +- .../workflow/nodes/test_continue_on_error.py | 4 +- .../core/workflow/nodes/test_if_else.py | 2 +- .../core/workflow/test_variable_pool.py | 2 +- api/tests/unit_tests/libs/test_email_i18n.py | 46 +++++++++---------- api/tests/unit_tests/libs/test_rsa.py | 2 +- api/tests/unit_tests/models/test_account.py | 2 +- 306 files changed, 787 insertions(+), 817 deletions(-) diff --git a/api/configs/remote_settings_sources/base.py b/api/configs/remote_settings_sources/base.py index a96ffdfb4..44ac2acd0 100644 --- a/api/configs/remote_settings_sources/base.py +++ b/api/configs/remote_settings_sources/base.py @@ -11,5 +11,5 @@ class RemoteSettingsSource: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: raise NotImplementedError - def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: + def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool): return value diff --git a/api/configs/remote_settings_sources/nacos/__init__.py b/api/configs/remote_settings_sources/nacos/__init__.py index d4fcd2c96..c6efd6f3a 100644 --- a/api/configs/remote_settings_sources/nacos/__init__.py +++ b/api/configs/remote_settings_sources/nacos/__init__.py @@ -33,7 +33,7 @@ class NacosSettingsSource(RemoteSettingsSource): logger.exception("[get-access-token] exception occurred") raise - def _parse_config(self, content: str) -> dict: + def _parse_config(self, content: str): if not content: return {} try: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 497fd53df..a2cb22601 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self) -> dict: + def post(self): parser = reqparse.RequestParser() parser.add_argument("type", type=str, required=True, default=False, location="json") args = parser.parse_args() diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index a0b73f7e0..5fced3e90 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,5 +1,5 @@ import logging -from typing import Any, NoReturn +from typing import NoReturn from flask import Response from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse @@ -29,7 +29,7 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) -def _convert_values_to_json_serializable_object(value: Segment) -> Any: +def _convert_values_to_json_serializable_object(value: Segment): if isinstance(value, FileSegment): return value.value.model_dump() elif isinstance(value, ArrayFileSegment): @@ -40,7 +40,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any: return value.value -def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: +def _serialize_var_value(variable: WorkflowDraftVariable): value = variable.get_value() # create a copy of the value to avoid affecting the model cache. value = value.model_copy(deep=True) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index eef9ddc76..43b59d533 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -99,7 +99,7 @@ class MCPAppApi(Resource): return mcp_server, app - def _validate_server_status(self, mcp_server: AppMCPServer) -> None: + def _validate_server_status(self, mcp_server: AppMCPServer): """Validate MCP server status""" if mcp_server.status != AppMCPServerStatus.ACTIVE: raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f5e45bcb4..1bcf83de6 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner): model_instance: ModelInstance, memory: Optional[TokenBufferMemory] = None, prompt_messages: Optional[list[PromptMessage]] = None, - ) -> None: + ): self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity self.conversation = conversation diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 6cb107712..b94a60c40 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): return instruction - def _init_react_state(self, query) -> None: + def _init_react_state(self, query): """ init agent scratchpad """ diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index a31c1050b..816d2782f 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel): action_name: str action_input: Union[dict, str] - def to_dict(self) -> dict: + def to_dict(self): """ Convert to dictionary. """ diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index a5492d70b..fcbf479e2 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -158,7 +158,7 @@ class DatasetConfigManager: return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] @classmethod - def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: + def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict): """ Extract dataset config for legacy compatibility diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 54bca10fc..781a703a0 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -105,7 +105,7 @@ class ModelConfigManager: return dict(config), ["model"] @classmethod - def validate_model_completion_params(cls, cp: dict) -> dict: + def validate_model_completion_params(cls, cp: dict): # model.completion_params if not isinstance(cp, dict): raise ValueError("model.completion_params must be of object type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index fa30511f6..e6ab31e58 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -122,7 +122,7 @@ class PromptTemplateConfigManager: return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] @classmethod - def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: + def validate_post_prompt_and_set_defaults(cls, config: dict): """ Validate post_prompt and set defaults for prompt feature diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index cb606953c..e4b308a6f 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): """ Validate for advanced chat app model config diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 561af7bac..84b032d6c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id: str, context: contextvars.Context, variable_loader: VariableLoader, - ) -> None: + ): """ Generate worker in a new thread. :param flask_app: Flask app diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 5e20e80d1..635754a20 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -54,7 +54,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow: Workflow, system_user_id: str, app: App, - ) -> None: + ): super().__init__( queue_manager=queue_manager, variable_loader=variable_loader, @@ -68,7 +68,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.system_user_id = system_user_id self._app = app - def run(self) -> None: + def run(self): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) @@ -221,7 +221,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): return False - def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy): """ Direct output """ diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 750e13c50..8207b70f9 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -101,7 +101,7 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, - ) -> None: + ): self._base_task_pipeline = BasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, queue_manager=queue_manager, @@ -289,7 +289,7 @@ class AdvancedChatAppGenerateTaskPipeline: session.rollback() raise - def _ensure_workflow_initialized(self) -> None: + def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" if not self._workflow_run_id: raise ValueError("workflow run not initialized.") @@ -888,7 +888,7 @@ class AdvancedChatAppGenerateTaskPipeline: if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 55b6ee510..349b58383 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): """ Validate for agent chat app model config diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 8665bc9d1..c6d98374c 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): queue_manager: AppQueueManager, conversation_id: str, message_id: str, - ) -> None: + ): """ Generate worker in a new thread. :param flask_app: Flask app diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d3207365f..388bed525 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner): queue_manager: AppQueueManager, conversation: Conversation, message: Message, - ) -> None: + ): """ Run assistant application :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 0eea13516..89a5b8e3b 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index af3731bdc..74c6d2eca 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC): return metadata @classmethod - def _error_to_stream_response(cls, e: Exception) -> dict: + def _error_to_stream_response(cls, e: Exception): """ Error to stream response. :param e: exception diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index b420ffb8b..6681fc6e4 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -157,7 +157,7 @@ class BaseAppGenerator: return value - def _sanitize_value(self, value: Any) -> Any: + def _sanitize_value(self, value: Any): if isinstance(value, str): return value.replace("\x00", "") return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 6bb513227..795a7beff 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -25,7 +25,7 @@ class PublishFrom(IntEnum): class AppQueueManager: - def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom): if not user_id: raise ValueError("user is required") @@ -73,14 +73,14 @@ class AppQueueManager: self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) last_ping_time = elapsed_time // 10 - def stop_listen(self) -> None: + def stop_listen(self): """ Stop listen to queue :return: """ self._q.put(None) - def publish_error(self, e, pub_from: PublishFrom) -> None: + def publish_error(self, e, pub_from: PublishFrom): """ Publish error :param e: error @@ -89,7 +89,7 @@ class AppQueueManager: """ self.publish(QueueErrorEvent(error=e), pub_from) - def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + def publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue :param event: @@ -100,7 +100,7 @@ class AppQueueManager: self._publish(event, pub_from) @abstractmethod - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue :param event: @@ -110,7 +110,7 @@ class AppQueueManager: raise NotImplementedError @classmethod - def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: + def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str): """ Set task stop flag :return: diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 6e8c261a6..dafdcdd42 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -162,7 +162,7 @@ class AppRunner: text: str, stream: bool, usage: Optional[LLMUsage] = None, - ) -> None: + ): """ Direct output :param queue_manager: application queue manager @@ -204,7 +204,7 @@ class AppRunner: queue_manager: AppQueueManager, stream: bool, agent: bool = False, - ) -> None: + ): """ Handle invoke result :param invoke_result: invoke result @@ -220,9 +220,7 @@ class AppRunner: else: raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") - def _handle_invoke_result_direct( - self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool - ) -> None: + def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): """ Handle invoke result direct :param invoke_result: invoke result @@ -239,7 +237,7 @@ class AppRunner: def _handle_invoke_result_stream( self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool - ) -> None: + ): """ Handle invoke result :param invoke_result: invoke result diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 96dc7dda7..96a3db850 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: + def config_validate(cls, tenant_id: str, config: dict): """ Validate for chat app model config diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index c273776eb..8bd956b31 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): queue_manager: AppQueueManager, conversation_id: str, message_id: str, - ) -> None: + ): """ Generate worker in a new thread. :param flask_app: Flask app diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 4385d0f08..d082cf2d3 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner): queue_manager: AppQueueManager, conversation: Conversation, message: Message, - ) -> None: + ): """ Run application :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 13a6be167..816d6d79a 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index c8760d3cf..937b2a7dd 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -62,7 +62,7 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - ) -> None: + ): self._application_generate_entity = application_generate_entity self._user = user diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 02e5d4756..3a1f29689 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: + def config_validate(cls, tenant_id: str, config: dict): """ Validate for completion app model config diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 8d2f3d488..6e43e5ec9 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message_id: str, - ) -> None: + ): """ Generate worker in a new thread. :param flask_app: Flask app diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index d384bff25..6c4bf4139 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner): def run( self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message - ) -> None: + ): """ Run application :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index c2b78e817..4d45c6114 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 4100a0d5a..67fc016cb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -14,14 +14,14 @@ from core.app.entities.queue_entities import ( class MessageBasedAppQueueManager(AppQueueManager): def __init__( self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str - ) -> None: + ): super().__init__(task_id, user_id, invoke_from) self._conversation_id = str(conversation_id) self._app_mode = app_mode self._message_id = str(message_id) - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue :param event: diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index b0aa21c73..e72da91c2 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False): """ Validate for workflow app model config diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 22b023460..60395f041 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -435,7 +435,7 @@ class WorkflowAppGenerator(BaseAppGenerator): context: contextvars.Context, variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, - ) -> None: + ): """ Generate worker in a new thread. :param flask_app: Flask app diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 40fc03afb..9985e2d27 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -14,12 +14,12 @@ from core.app.entities.queue_entities import ( class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str): super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue :param event: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 4f4c1460a..42b357580 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_thread_pool_id: Optional[str] = None, workflow: Workflow, system_user_id: str, - ) -> None: + ): super().__init__( queue_manager=queue_manager, variable_loader=variable_loader, @@ -45,7 +45,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._workflow = workflow self._sys_user_id = system_user_id - def run(self) -> None: + def run(self): """ Run application """ diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 917ede617..210f6110b 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c10f95475..6ab89dbd6 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -92,7 +92,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, draft_var_saver_factory: DraftVariableSaverFactory, - ) -> None: + ): self._base_task_pipeline = BasedGenerateTaskPipeline( application_generate_entity=application_generate_entity, queue_manager=queue_manager, @@ -263,7 +263,7 @@ class WorkflowAppGenerateTaskPipeline: session.rollback() raise - def _ensure_workflow_initialized(self) -> None: + def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" if not self._workflow_run_id: raise ValueError("workflow run not initialized.") @@ -744,7 +744,7 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution): invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 948ea95e6..b6cb88ea8 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -74,7 +74,7 @@ class WorkflowBasedAppRunner: queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, app_id: str, - ) -> None: + ): self._queue_manager = queue_manager self._variable_loader = variable_loader self._app_id = app_id @@ -292,7 +292,7 @@ class WorkflowBasedAppRunner: return graph, variable_pool - def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event :param workflow_entry: workflow entry @@ -694,5 +694,5 @@ class WorkflowBasedAppRunner: ) ) - def _publish_event(self, event: AppQueueEvent) -> None: + def _publish_event(self, event: AppQueueEvent): self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index d04855e99..7d98cceb1 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline: application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, stream: bool, - ) -> None: + ): self._application_generate_entity = application_generate_entity self.queue_manager = queue_manager self._start_at = time.perf_counter() diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e3b917067..0dad0a5a9 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): conversation: Conversation, message: Message, stream: bool, - ) -> None: + ): super().__init__( application_generate_entity=application_generate_entity, queue_manager=queue_manager, @@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None): """ Save message. :return: @@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): application_generate_entity=self._application_generate_entity, ) - def _handle_stop(self, event: QueueStopEvent) -> None: + def _handle_stop(self, event: QueueStopEvent): """ Handle stop. :return: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 8ea4a4ec3..e865ba9d6 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -48,7 +48,7 @@ class MessageCycleManager: AdvancedChatAppGenerateEntity, ], task_state: Union[EasyUITaskState, WorkflowTaskState], - ) -> None: + ): self._application_generate_entity = application_generate_entity self._task_state = task_state @@ -132,7 +132,7 @@ class MessageCycleManager: return None - def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: + def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent): """ Handle retriever resources. :param event: event diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 65d899a00..30cdab26d 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str: return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None): """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) @@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel): color: Optional[str] = "" current_loop: int = 1 - def __init__(self, color: Optional[str] = None) -> None: + def __init__(self, color: Optional[str] = None): super().__init__() """Initialize callback handler.""" # use a specific color is not specified @@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel): self, tool_name: str, tool_inputs: Mapping[str, Any], - ) -> None: + ): """Do nothing.""" if dify_config.DEBUG: print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) @@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel): message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None, - ) -> None: + ): """If not the final action, print out observation.""" if dify_config.DEBUG: print_text("\n[on_tool_end]\n", color=self.color) @@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel): ) ) - def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any): """Do nothing.""" if dify_config.DEBUG: print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start(self, thought: str) -> None: + def on_agent_start(self, thought: str): """Run on agent start.""" if dify_config.DEBUG: if thought: @@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel): else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any): """Run on agent end.""" if dify_config.DEBUG: print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index c85d2d599..14d5f38dc 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler: def __init__( self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom - ) -> None: + ): self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id self._user_id = user_id self._invoke_from = invoke_from - def on_query(self, query: str, dataset_id: str) -> None: + def on_query(self, query: str, dataset_id: str): """ Handle query. """ @@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler: db.session.add(dataset_query) db.session.commit() - def on_tool_end(self, documents: list[Document]) -> None: + def on_tool_end(self, documents: list[Document]): """Handle tool end.""" for document in documents: if document.metadata is not None: diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index ac64a8e3a..0fd49b059 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel): icon_large: Optional[I18nObject] = None supported_model_types: list[ModelType] - def __init__(self, provider_entity: ProviderEntity) -> None: + def __init__(self, provider_entity: ProviderEntity): """ Init simple provider. @@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel): load_balancing_enabled: bool = False has_invalid_load_balancing_configs: bool = False - def raise_for_status(self) -> None: + def raise_for_status(self): """ Check model status and raise ValueError if not active. diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9119462ac..61a960c3d 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials( - self, credentials: dict, credential_id: str = "", session: Session | None = None - ) -> dict: + def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): """ Validate custom credentials. :param credentials: provider credentials @@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel): :return: """ - def _validate(s: Session) -> dict: + def _validate(s: Session): # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas @@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel): logger.warning("Error generating next credential name: %s", str(e)) return "API KEY 1" - def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: + def create_provider_credential(self, credentials: dict, credential_name: str | None): """ Add custom provider credentials. :param credentials: provider credentials @@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel): credentials: dict, credential_id: str, credential_name: str | None, - ) -> None: + ): """ update a saved provider credential (by credential_id). @@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel): credential_record: ProviderCredential | ProviderModelCredential, credential_source: str, session: Session, - ) -> None: + ): """ Update load balancing configurations that reference the given credential_id. @@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel): session.commit() - def delete_provider_credential(self, credential_id: str) -> None: + def delete_provider_credential(self, credential_id: str): """ Delete a saved provider credential (by credential_id). @@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel): session.rollback() raise - def switch_active_provider_credential(self, credential_id: str) -> None: + def switch_active_provider_credential(self, credential_id: str): """ Switch active provider credential (copy the selected one into current active snapshot). @@ -814,7 +812,7 @@ class ProviderConfiguration(BaseModel): credentials: dict, credential_id: str = "", session: Session | None = None, - ) -> dict: + ): """ Validate custom model credentials. @@ -825,7 +823,7 @@ class ProviderConfiguration(BaseModel): :return: """ - def _validate(s: Session) -> dict: + def _validate(s: Session): # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas @@ -1009,7 +1007,7 @@ class ProviderConfiguration(BaseModel): session.rollback() raise - def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): """ Delete a saved provider credential (by credential_id). @@ -1079,7 +1077,7 @@ class ProviderConfiguration(BaseModel): session.rollback() raise - def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: + def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str): """ if model list exist this custom model, switch the custom model credential. if model list not exist this custom model, use the credential to add a new custom model record. @@ -1122,7 +1120,7 @@ class ProviderConfiguration(BaseModel): session.add(provider_model_record) session.commit() - def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): """ switch the custom model credential. @@ -1152,7 +1150,7 @@ class ProviderConfiguration(BaseModel): session.add(provider_model_record) session.commit() - def delete_custom_model(self, model_type: ModelType, model: str) -> None: + def delete_custom_model(self, model_type: ModelType, model: str): """ Delete custom model. :param model_type: model type @@ -1347,7 +1345,7 @@ class ProviderConfiguration(BaseModel): provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) - def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: + def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): """ Switch preferred provider type. :param provider_type: @@ -1359,7 +1357,7 @@ class ProviderConfiguration(BaseModel): if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return - def _switch(s: Session) -> None: + def _switch(s: Session): # get preferred provider model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] @@ -1403,7 +1401,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: + def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): """ Obfuscated credentials. diff --git a/api/core/errors/error.py b/api/core/errors/error.py index ad921bc25..642f24a41 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -6,7 +6,7 @@ class LLMError(ValueError): description: Optional[str] = None - def __init__(self, description: Optional[str] = None) -> None: + def __init__(self, description: Optional[str] = None): self.description = description diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4423299f7..fab9ae44e 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -10,11 +10,11 @@ class APIBasedExtensionRequestor: timeout: tuple[int, int] = (5, 60) """timeout for request connect and read""" - def __init__(self, api_endpoint: str, api_key: str) -> None: + def __init__(self, api_endpoint: str, api_key: str): self.api_endpoint = api_endpoint self.api_key = api_key - def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: + def request(self, point: APIBasedExtensionPoint, params: dict): """ Request the api. diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 1c4fa60ab..eee914a52 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -34,7 +34,7 @@ class Extensible: tenant_id: str config: Optional[dict] = None - def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: + def __init__(self, tenant_id: str, config: Optional[dict] = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 2100e7fad..45878e763 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index 0db736f09..81f1aaf17 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: + def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 75a638acb..538bc3f52 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: + def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + def validate_config(cls, name: str, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index fac68beb0..4c8e7282b 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -7,6 +7,6 @@ if TYPE_CHECKING: _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None -def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: +def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]): global _tool_file_manager_factory _tool_file_manager_factory = factory diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index e233a596b..701208080 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel): pass @classmethod - def get_default_config(cls) -> dict: + def get_default_config(cls): return { "type": "code", "config": { diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index 54c78cdf9..969125d2f 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer class Jinja2TemplateTransformer(TemplateTransformer): @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str): """ Transform response to dict :param response: response diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 9cca8af7c..151bf0e20 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider): def get_default_code(cls) -> str: return dedent( """ - def main(arg1: str, arg2: str) -> dict: + def main(arg1: str, arg2: str): return { "result": arg1 + arg2, } diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 35349210b..1c112007c 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -34,7 +34,7 @@ class ProviderCredentialsCache: else: return None - def set(self, credentials: dict) -> None: + def set(self, credentials: dict): """ Cache model provider credentials. @@ -43,7 +43,7 @@ class ProviderCredentialsCache: """ redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - def delete(self) -> None: + def delete(self): """ Delete cached model provider credentials. diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 48ec3be5c..26e738fce 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC): return None return None - def set(self, config: dict[str, Any]) -> None: + def set(self, config: dict[str, Any]): """Cache provider credentials""" redis_client.setex(self.cache_key, 86400, json.dumps(config)) - def delete(self) -> None: + def delete(self): """Delete cached provider credentials""" redis_client.delete(self.cache_key) @@ -75,10 +75,10 @@ class NoOpProviderCredentialCache: """Get cached provider credentials""" return None - def set(self, config: dict[str, Any]) -> None: + def set(self, config: dict[str, Any]): """Cache provider credentials""" pass - def delete(self) -> None: + def delete(self): """Delete cached provider credentials""" pass diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 918b3e9ee..95a1086ca 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -37,11 +37,11 @@ class ToolParameterCache: else: return None - def set(self, parameters: dict) -> None: + def set(self, parameters: dict): """Cache model provider credentials.""" redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) - def delete(self) -> None: + def delete(self): """ Delete cached model provider credentials. diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 5cd0ea5c6..35e6e292d 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]: return None -def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: +def extract_external_trace_id_from_args(args: Mapping[str, Any]): """ Extract 'external_trace_id' from args. diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 20d98562d..a5d7f7aac 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -44,11 +44,11 @@ class HostingConfiguration: provider_map: dict[str, HostingProvider] moderation_config: Optional[HostedModerationConfig] = None - def __init__(self) -> None: + def __init__(self): self.provider_map = {} self.moderation_config = None - def init_app(self, app: Flask) -> None: + def init_app(self, app: Flask): if dify_config.EDITION != "CLOUD": return diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 37eb3eab6..89a05e02c 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -512,7 +512,7 @@ class IndexingRunner: dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document], - ) -> None: + ): """ insert index and update document/segment status to completed """ @@ -651,7 +651,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None - ) -> None: + ): """ Update the document indexing status. """ @@ -670,7 +670,7 @@ class IndexingRunner: db.session.commit() @staticmethod - def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: + def _update_segments_by_document(dataset_document_id: str, update_params: dict): """ Update the document segment by document id. """ diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 894f090c1..94b8258e9 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -127,7 +127,7 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: + def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool): output_parser = RuleConfigGeneratorOutputParser() error = "" @@ -262,9 +262,7 @@ class LLMGenerator: return rule_config @classmethod - def generate_code( - cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" - ) -> dict: + def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"): if code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: @@ -373,7 +371,7 @@ class LLMGenerator: @staticmethod def instruction_modify_legacy( tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None - ) -> dict: + ): last_run: Message | None = ( db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() ) @@ -413,7 +411,7 @@ class LLMGenerator: instruction: str, model_config: dict, ideal_output: str | None, - ) -> dict: + ): from services.workflow_service import WorkflowService app: App | None = db.session.query(App).where(App.id == flow_id).first() @@ -451,7 +449,7 @@ class LLMGenerator: return [] parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) - def dict_of_event(event: AgentLogEvent) -> dict: + def dict_of_event(event: AgentLogEvent): return { "status": event.status, "error": event.error, @@ -488,7 +486,7 @@ class LLMGenerator: instruction: str, node_type: str, ideal_output: str | None, - ) -> dict: + ): LAST_RUN = "{{#last_run#}}" CURRENT = "{{#current#}}" ERROR_MESSAGE = "{{#error_message#}}" diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 0c7683b16..95fc6dbec 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,5 +1,3 @@ -from typing import Any - from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, @@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser: RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, ) - def parse(self, text: str) -> Any: + def parse(self, text: str): try: expected_keys = ["prompt", "variables", "opening_statement"] parsed = parse_and_check_json_markdown(text, expected_keys) diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 151cef1bc..28833fe8e 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -210,7 +210,7 @@ def _handle_native_json_schema( structured_output_schema: Mapping, model_parameters: dict, rules: list[ParameterRule], -) -> dict: +): """ Handle structured output for models with native JSON schema support. @@ -232,7 +232,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict, rules: list) -> None: +def _set_response_format(model_parameters: dict, rules: list): """ Set the appropriate response format parameter based on model rules. @@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]: return structured_output -def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: +def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping): """ Prepare JSON schema based on model requirements. @@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema return {"schema": processed_schema, "name": "llm_response"} -def remove_additional_properties(schema: dict) -> None: +def remove_additional_properties(schema: dict): """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. @@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None: remove_additional_properties(item) -def convert_boolean_to_string(schema: dict) -> None: +def convert_boolean_to_string(schema: dict): """ Convert boolean type specifications to string in JSON schema. diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 98cdc4c8b..e78859cc1 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -1,6 +1,5 @@ import json import re -from typing import Any from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT - def parse(self, text: str) -> Any: + def parse(self, text: str): action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index bad99fc09..bf1820f74 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -44,7 +44,7 @@ class OAuthClientProvider: return None return OAuthClientInformation.model_validate(client_information) - def save_client_information(self, client_information: OAuthClientInformationFull) -> None: + def save_client_information(self, client_information: OAuthClientInformationFull): """Saves client information after dynamic registration.""" MCPToolManageService.update_mcp_provider_credentials( self.mcp_provider, @@ -63,13 +63,13 @@ class OAuthClientProvider: refresh_token=credentials.get("refresh_token", ""), ) - def save_tokens(self, tokens: OAuthTokens) -> None: + def save_tokens(self, tokens: OAuthTokens): """Stores new OAuth tokens for the current session.""" # update mcp provider credentials token_dict = tokens.model_dump() MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) - def save_code_verifier(self, code_verifier: str) -> None: + def save_code_verifier(self, code_verifier: str): """Saves a PKCE code verifier for the current session.""" MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index cc38954ec..cc4263c0a 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -47,7 +47,7 @@ class SSETransport: headers: dict[str, Any] | None = None, timeout: float = 5.0, sse_read_timeout: float = 5 * 60, - ) -> None: + ): """Initialize the SSE transport. Args: @@ -76,7 +76,7 @@ class SSETransport: return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme - def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: + def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue): """Handle an 'endpoint' SSE event. Args: @@ -94,7 +94,7 @@ class SSETransport: status_queue.put(_StatusReady(endpoint_url)) - def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: + def _handle_message_event(self, sse_data: str, read_queue: ReadQueue): """Handle a 'message' SSE event. Args: @@ -110,7 +110,7 @@ class SSETransport: logger.exception("Error parsing server message") read_queue.put(exc) - def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue): """Handle a single SSE event. Args: @@ -126,7 +126,7 @@ class SSETransport: case _: logger.warning("Unknown SSE event: %s", sse.event) - def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue): """Read and process SSE events. Args: @@ -144,7 +144,7 @@ class SSETransport: finally: read_queue.put(None) - def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: + def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage): """Send a single message to the server. Args: @@ -163,7 +163,7 @@ class SSETransport: response.raise_for_status() logger.debug("Client message sent successfully: %s", response.status_code) - def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: + def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue): """Handle writing messages to the server. Args: @@ -303,7 +303,7 @@ def sse_client( write_queue.put(None) -def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: +def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage): """ Send a message to the server using the provided HTTP client. diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index a2b003e71..7eafa7983 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -82,7 +82,7 @@ class StreamableHTTPTransport: headers: dict[str, Any] | None = None, timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, - ) -> None: + ): """Initialize the StreamableHTTP transport. Args: @@ -122,7 +122,7 @@ class StreamableHTTPTransport: def _maybe_extract_session_id_from_response( self, response: httpx.Response, - ) -> None: + ): """Extract and store session ID from response headers.""" new_session_id = response.headers.get(MCP_SESSION_ID) if new_session_id: @@ -173,7 +173,7 @@ class StreamableHTTPTransport: self, client: httpx.Client, server_to_client_queue: ServerToClientQueue, - ) -> None: + ): """Handle GET stream for server-initiated messages.""" try: if not self.session_id: @@ -197,7 +197,7 @@ class StreamableHTTPTransport: except Exception as exc: logger.debug("GET stream error (non-fatal): %s", exc) - def _handle_resumption_request(self, ctx: RequestContext) -> None: + def _handle_resumption_request(self, ctx: RequestContext): """Handle a resumption request using GET with SSE.""" headers = self._update_headers_with_session(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: @@ -230,7 +230,7 @@ class StreamableHTTPTransport: if is_complete: break - def _handle_post_request(self, ctx: RequestContext) -> None: + def _handle_post_request(self, ctx: RequestContext): """Handle a POST request with response processing.""" headers = self._update_headers_with_session(ctx.headers) message = ctx.session_message.message @@ -278,7 +278,7 @@ class StreamableHTTPTransport: self, response: httpx.Response, server_to_client_queue: ServerToClientQueue, - ) -> None: + ): """Handle JSON response from the server.""" try: content = response.read() @@ -288,7 +288,7 @@ class StreamableHTTPTransport: except Exception as exc: server_to_client_queue.put(exc) - def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: + def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext): """Handle SSE response from the server.""" try: event_source = EventSource(response) @@ -307,7 +307,7 @@ class StreamableHTTPTransport: self, content_type: str, server_to_client_queue: ServerToClientQueue, - ) -> None: + ): """Handle unexpected content type in response.""" error_msg = f"Unexpected content type: {content_type}" logger.error(error_msg) @@ -317,7 +317,7 @@ class StreamableHTTPTransport: self, server_to_client_queue: ServerToClientQueue, request_id: RequestId, - ) -> None: + ): """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", @@ -333,7 +333,7 @@ class StreamableHTTPTransport: client_to_server_queue: ClientToServerQueue, server_to_client_queue: ServerToClientQueue, start_get_stream: Callable[[], None], - ) -> None: + ): """Handle writing requests to the server. This method processes messages from the client_to_server_queue and sends them to the server. @@ -379,7 +379,7 @@ class StreamableHTTPTransport: except Exception as exc: server_to_client_queue.put(exc) - def terminate_session(self, client: httpx.Client) -> None: + def terminate_session(self, client: httpx.Client): """Terminate the session by sending a DELETE request.""" if not self.session_id: return @@ -441,7 +441,7 @@ def streamablehttp_client( timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), ) as client: # Define callbacks that need access to thread pool - def start_get_stream() -> None: + def start_get_stream(): """Start a worker thread to handle server-initiated messages.""" executor.submit(transport.handle_get_stream, client, server_to_client_queue) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 1bd533581..96c48034c 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): ReceiveNotificationT ]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], - ) -> None: + ): self.request_id = request_id self.request_meta = request_meta self.request = request @@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> None: + ): """Exit the context manager, performing cleanup and notifying completion.""" try: if self._completed: @@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): finally: self._entered = False - def respond(self, response: SendResultT | ErrorData) -> None: + def respond(self, response: SendResultT | ErrorData): """Send a response for this request. Must be called within a context manager block. @@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._session._send_response(request_id=self.request_id, response=response) - def cancel(self) -> None: + def cancel(self): """Cancel this request and mark it as completed.""" if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") @@ -163,7 +163,7 @@ class BaseSession( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, - ) -> None: + ): self._read_stream = read_stream self._write_stream = write_stream self._response_streams = {} @@ -183,7 +183,7 @@ class BaseSession( self._receiver_future = self._executor.submit(self._receive_loop) return self - def check_receiver_status(self) -> None: + def check_receiver_status(self): """`check_receiver_status` ensures that any exceptions raised during the execution of `_receive_loop` are retrieved and propagated.""" if self._receiver_future and self._receiver_future.done(): @@ -191,7 +191,7 @@ class BaseSession( def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None - ) -> None: + ): self._read_stream.put(None) self._write_stream.put(None) @@ -277,7 +277,7 @@ class BaseSession( self, notification: SendNotificationT, related_request_id: RequestId | None = None, - ) -> None: + ): """ Emits a notification, which is a one-way message that does not expect a response. @@ -296,7 +296,7 @@ class BaseSession( ) self._write_stream.put(session_message) - def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: + def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData): if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) @@ -310,7 +310,7 @@ class BaseSession( session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) self._write_stream.put(session_message) - def _receive_loop(self) -> None: + def _receive_loop(self): """ Main message processing loop. In a real synchronous implementation, this would likely run in a separate thread. @@ -382,7 +382,7 @@ class BaseSession( logger.exception("Error in message processing loop") raise - def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: + def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]): """ Can be overridden by subclasses to handle a request without needing to listen on the message stream. @@ -391,15 +391,13 @@ class BaseSession( forwarded on to the message stream. """ - def _received_notification(self, notification: ReceiveNotificationT) -> None: + def _received_notification(self, notification: ReceiveNotificationT): """ Can be overridden by subclasses to handle a notification without needing to listen on the message stream. """ - def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None - ) -> None: + def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): """ Sends a progress notification for a request that is currently being processed. @@ -408,5 +406,5 @@ class BaseSession( def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, - ) -> None: + ): """A generic handler for incoming messages. Overwritten by subclasses.""" diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 1bccf1d03..5817416ba 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -28,19 +28,19 @@ class LoggingFnT(Protocol): def __call__( self, params: types.LoggingMessageNotificationParams, - ) -> None: ... + ): ... class MessageHandlerFnT(Protocol): def __call__( self, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: ... + ): ... def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, -) -> None: +): if isinstance(message, Exception): raise ValueError(str(message)) elif isinstance(message, (types.ServerNotification | RequestResponder)): @@ -68,7 +68,7 @@ def _default_list_roots_callback( def _default_logging_callback( params: types.LoggingMessageNotificationParams, -) -> None: +): pass @@ -94,7 +94,7 @@ class ClientSession( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, - ) -> None: + ): super().__init__( read_stream, write_stream, @@ -155,9 +155,7 @@ class ClientSession( types.EmptyResult, ) - def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None - ) -> None: + def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): """Send a progress notification.""" self.send_notification( types.ClientNotification( @@ -314,7 +312,7 @@ class ClientSession( types.ListToolsResult, ) - def send_roots_list_changed(self) -> None: + def send_roots_list_changed(self): """Send a roots/list_changed notification.""" self.send_notification( types.ClientNotification( @@ -324,7 +322,7 @@ class ClientSession( ) ) - def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: + def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, meta=responder.request_meta, @@ -352,11 +350,11 @@ class ClientSession( def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: + ): """Handle incoming messages by forwarding to the message handler.""" self._message_handler(req) - def _received_notification(self, notification: types.ServerNotification) -> None: + def _received_notification(self, notification: types.ServerNotification): """Handle notifications from the server.""" # Process specific notification types match notification.root: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 17050fcad..f2178b027 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -27,7 +27,7 @@ class TokenBufferMemory: self, conversation: Conversation, model_instance: ModelInstance, - ) -> None: + ): self.conversation = conversation self.model_instance = model_instance diff --git a/api/core/model_manager.py b/api/core/model_manager.py index e56756554..a59b0ae82 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -32,7 +32,7 @@ class ModelInstance: Model instance class """ - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): self.provider_model_bundle = provider_model_bundle self.model = model self.provider = provider_model_bundle.configuration.provider.provider @@ -46,7 +46,7 @@ class ModelInstance: ) @staticmethod - def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: + def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -342,7 +342,7 @@ class ModelInstance: ), ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): """ Round-robin invoke :param function: function to invoke @@ -379,7 +379,7 @@ class ModelInstance: except Exception as e: raise e - def get_tts_voices(self, language: Optional[str] = None) -> list: + def get_tts_voices(self, language: Optional[str] = None): """ Invoke large language tts model voices @@ -394,7 +394,7 @@ class ModelInstance: class ModelManager: - def __init__(self) -> None: + def __init__(self): self._provider_manager = ProviderManager() def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: @@ -453,7 +453,7 @@ class LBModelManager: model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], managed_credentials: Optional[dict] = None, - ) -> None: + ): """ Load balancing model manager :param tenant_id: tenant_id @@ -534,7 +534,7 @@ model: %s""", return config - def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: + def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60): """ Cooldown model load balancing config :param config: model load balancing config diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 57cad1728..5ce4c23db 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -35,7 +35,7 @@ class Callback(ABC): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ Before invoke callback @@ -94,7 +94,7 @@ class Callback(ABC): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ After invoke callback @@ -124,7 +124,7 @@ class Callback(ABC): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ Invoke error callback @@ -141,7 +141,7 @@ class Callback(ABC): """ raise NotImplementedError() - def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = ""): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 899f08195..8411afca9 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -24,7 +24,7 @@ class LoggingCallback(Callback): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ Before invoke callback @@ -110,7 +110,7 @@ class LoggingCallback(Callback): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ After invoke callback @@ -151,7 +151,7 @@ class LoggingCallback(Callback): stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, - ) -> None: + ): """ Invoke error callback diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 767542536..6bcb70768 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -6,7 +6,7 @@ class InvokeError(ValueError): description: Optional[str] = None - def __init__(self, description: Optional[str] = None) -> None: + def __init__(self, description: Optional[str] = None): self.description = description def __str__(self): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 7d5ce1e47..f41818e27 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -239,7 +239,7 @@ class AIModel(BaseModel): """ return None - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: + def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): """ Get default parameter rule for given name diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index ce378b443..24b206fdb 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -408,7 +408,7 @@ class LargeLanguageModel(AIModel): stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, - ) -> None: + ): """ Trigger before invoke callbacks @@ -456,7 +456,7 @@ class LargeLanguageModel(AIModel): stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, - ) -> None: + ): """ Trigger new chunk callbacks @@ -503,7 +503,7 @@ class LargeLanguageModel(AIModel): stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, - ) -> None: + ): """ Trigger after invoke callbacks @@ -553,7 +553,7 @@ class LargeLanguageModel(AIModel): stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, - ) -> None: + ): """ Trigger invoke error callbacks diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index cb740c1fd..8f8a638af 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -28,7 +28,7 @@ class GPT2Tokenizer: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) @staticmethod - def get_encoder() -> Any: + def get_encoder(): global _tokenizer, _lock if _tokenizer is not None: return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index d51831900..9ee29f2f2 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -56,7 +56,7 @@ class TTSModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None): """ Retrieves the list of voices supported by a given text-to-speech (TTS) model. diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 24cf69a50..6502b920f 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -36,7 +36,7 @@ class ModelProviderExtension(BaseModel): class ModelProviderFactory: provider_position_map: dict[str, int] - def __init__(self, tenant_id: str) -> None: + def __init__(self, tenant_id: str): self.provider_position_map = {} self.tenant_id = tenant_id @@ -132,7 +132,7 @@ class ModelProviderFactory: return plugin_model_provider_entity - def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: + def provider_credentials_validate(self, *, provider: str, credentials: dict): """ Validate provider credentials @@ -163,9 +163,7 @@ class ModelProviderFactory: return filtered_credentials - def model_credentials_validate( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict - ) -> dict: + def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): """ Validate model credentials diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index b68900740..2caedeaf4 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -6,7 +6,7 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema, class CommonValidator: def _validate_and_filter_credential_form_schemas( self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ) -> dict: + ): need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index 7d1644d13..0ac935ca3 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -8,7 +8,7 @@ class ModelCredentialSchemaValidator(CommonValidator): self.model_type = model_type self.model_credential_schema = model_credential_schema - def validate_and_filter(self, credentials: dict) -> dict: + def validate_and_filter(self, credentials: dict): """ Validate model credentials diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index 6dff2428c..06350f92a 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -6,7 +6,7 @@ class ProviderCredentialSchemaValidator(CommonValidator): def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema - def validate_and_filter(self, credentials: dict) -> dict: + def validate_and_filter(self, credentials: dict): """ Validate provider credentials diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index f65339fbf..962e41767 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,7 +18,7 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): return model.model_dump(mode=mode, **kwargs) @@ -100,7 +100,7 @@ def jsonable_encoder( exclude_none: bool = False, custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, -) -> Any: +): custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 06d5c02bb..ce7bd2111 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -25,7 +25,7 @@ class ApiModeration(Moderation): name: str = "api" @classmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. @@ -75,7 +75,7 @@ class ApiModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: + def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict): if self.config is None: raise ValueError("The config is not set.") extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index f07947879..752617b65 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -34,13 +34,13 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: + def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None): super().__init__(tenant_id, config) self.app_id = app_id @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. @@ -76,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool): # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 9cda24d7a..c2c8be6d6 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -6,12 +6,12 @@ from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: __extension_instance: Moderation - def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: + def __init__(self, name: str, app_id: str, tenant_id: str, config: dict): extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) self.__extension_instance = extension_class(app_id, tenant_id, config) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: + def validate_config(cls, name: str, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 9dd2665c3..8d8d15374 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -8,7 +8,7 @@ class KeywordsModeration(Moderation): name: str = "keywords" @classmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index d64f17b38..74ef6f7ce 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -7,7 +7,7 @@ class OpenAIModeration(Moderation): name: str = "openai_moderation" @classmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict): """ Validate the incoming form config data. diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index f981737df..6993ec8b0 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -40,7 +40,7 @@ class OutputModeration(BaseModel): def get_final_output(self) -> str: return self.final_output or "" - def append_new_token(self, token: str) -> None: + def append_new_token(self, token: str): self.buffer += token if not self.thread: diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 213f5c726..fafc6e894 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -6,7 +6,7 @@ from models.account import Tenant class PluginEncrypter: @classmethod - def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: + def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt): encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 7898795ce..bed5927e1 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -27,7 +27,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): model_config: ParameterExtractorModelConfig, instruction: str, query: str, - ) -> dict: + ): """ Invoke parameter extractor node. @@ -78,7 +78,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): classes: list[ClassConfig], instruction: str, query: str, - ) -> dict: + ): """ Invoke question classifier node. diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 01e9e11e6..a6369636e 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -117,7 +117,7 @@ class PluginDeclaration(BaseModel): @model_validator(mode="before") @classmethod - def validate_category(cls, values: dict) -> dict: + def validate_category(cls, values: dict): # auto detect category if values.get("tool"): values["category"] = PluginCategory.Tool @@ -168,7 +168,7 @@ class GenericProviderID: def __str__(self) -> str: return f"{self.organization}/{self.plugin_name}/{self.provider_name}" - def __init__(self, value: str, is_hardcoded: bool = False) -> None: + def __init__(self, value: str, is_hardcoded: bool = False): if not value: raise NotFound("plugin not found, please add plugin") # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name @@ -191,14 +191,14 @@ class GenericProviderID: class ModelProviderID(GenericProviderID): - def __init__(self, value: str, is_hardcoded: bool = False) -> None: + def __init__(self, value: str, is_hardcoded: bool = False): super().__init__(value, is_hardcoded) if self.organization == "langgenius" and self.provider_name == "google": self.plugin_name = "gemini" class ToolProviderID(GenericProviderID): - def __init__(self, value: str, is_hardcoded: bool = False) -> None: + def __init__(self, value: str, is_hardcoded: bool = False): super().__init__(value, is_hardcoded) if self.organization == "langgenius": if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 3c994ce70..526f6f296 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -17,7 +17,7 @@ class PluginAgentClient(BasePluginClient): Fetch agent providers for the given tenant. """ - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} provider_name = declaration.get("identity", {}).get("name") @@ -49,7 +49,7 @@ class PluginAgentClient(BasePluginClient): """ agent_provider_id = GenericProviderID(provider) - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]): # skip if error occurs if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: return json_response diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 8ecc2e214..23a69bd92 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -8,7 +8,7 @@ from extensions.ext_logging import get_request_id class PluginDaemonError(Exception): """Base class for all plugin daemon errors.""" - def __init__(self, description: str) -> None: + def __init__(self, description: str): self.description = description def __str__(self) -> str: diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index f7607eef8..85a72d9f8 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -415,7 +415,7 @@ class PluginModelClient(BasePluginClient): model: str, credentials: dict, language: Optional[str] = None, - ) -> list[dict]: + ): """ Get tts model voices """ diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 4c1558efc..7199c0d15 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -16,7 +16,7 @@ class PluginToolManager(BasePluginClient): Fetch tool providers for the given tenant. """ - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} provider_name = declaration.get("identity", {}).get("name") @@ -48,7 +48,7 @@ class PluginToolManager(BasePluginClient): """ tool_provider_id = ToolProviderID(provider) - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]): data = json_response.get("data") if data: for tool in data.get("declaration", {}).get("tools", []): diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index 21ca2d8d3..ec66ba02e 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -18,7 +18,7 @@ class FileChunk: bytes_written: int = field(default=0, init=False) data: bytearray = field(init=False) - def __post_init__(self) -> None: + def __post_init__(self): self.data = bytearray(self.total_length) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 16c145f93..11c6e5c23 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -30,7 +30,7 @@ class AdvancedPromptTransform(PromptTransform): self, with_variable_tmpl: bool = False, image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, - ) -> None: + ): self.with_variable_tmpl = with_variable_tmpl self.image_detail_config = image_detail_config diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 13f4163d8..d75a230d7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -126,7 +126,7 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ) -> dict: + ): prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] @@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform): return prompt_message - def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict: + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str): """ Get simple prompt rule. :param app_mode: app mode diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index cdc6ccc82..0a7a46722 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -15,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]): """ Prompt messages to prompt for saving. :param model_mode: model mode diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 8e40674bc..1b936c089 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -25,7 +25,7 @@ class PromptTemplateParser: self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX self.variable_keys = self.extract() - def extract(self) -> list: + def extract(self): # Regular expression to match the template rules return re.findall(self.regex, self.template) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 4a3b8c9dd..13dcef1a1 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -59,7 +59,7 @@ class ProviderManager: ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ - def __init__(self) -> None: + def __init__(self): self.decoding_rsa_key = None self.decoding_cipher_rsa = None diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 5fb6f9fcc..096f40f70 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -76,7 +76,7 @@ class Jieba(BaseKeyword): return False return id in set.union(*keyword_table.values()) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() @@ -116,7 +116,7 @@ class Jieba(BaseKeyword): return documents - def delete(self) -> None: + def delete(self): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table @@ -168,14 +168,14 @@ class Jieba(BaseKeyword): return {} - def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: + def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]): for keyword in keywords: if keyword not in keyword_table: keyword_table[keyword] = set() keyword_table[keyword].add(id) return keyword_table - def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict: + def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]): # get set of ids that correspond to node node_idxs_to_delete = set(ids) @@ -251,7 +251,7 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) -def set_orjson_default(obj: Any) -> Any: +def set_orjson_default(obj: Any): """Default function for orjson serialization of set types""" if isinstance(obj, set): return list(obj) diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b261b40b7..0a5985530 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -24,11 +24,11 @@ class BaseKeyword(ABC): raise NotImplementedError @abstractmethod - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): raise NotImplementedError @abstractmethod - def delete(self) -> None: + def delete(self): raise NotImplementedError @abstractmethod diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index f1a6ade91..b2e1a55ee 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -36,10 +36,10 @@ class Keyword: def text_exists(self, id: str) -> bool: return self._keyword_processor.text_exists(id) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): self._keyword_processor.delete_by_ids(ids) - def delete(self) -> None: + def delete(self): self._keyword_processor.delete() def search(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b9e488362..ddb549ba9 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -46,10 +46,10 @@ class AnalyticdbVector(BaseVector): def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): self.analyticdb_vector.delete_by_ids(ids) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): self.analyticdb_vector.delete_by_metadata_field(key, value) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -58,7 +58,7 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self.analyticdb_vector.search_by_full_text(query, **kwargs) - def delete(self) -> None: + def delete(self): self.analyticdb_vector.delete() diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 48e3f20e3..c3a6127e4 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -26,7 +26,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["access_key_id"]: raise ValueError("config ANALYTICDB_KEY_ID is required") if not values["access_key_secret"]: @@ -65,7 +65,7 @@ class AnalyticdbVectorOpenAPI: self._client = Client(self._client_config) self._initialize() - def _initialize(self) -> None: + def _initialize(self): cache_key = f"vector_initialize_{self.config.instance_id}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -76,7 +76,7 @@ class AnalyticdbVectorOpenAPI: self._create_namespace_if_not_exists() redis_client.set(database_exist_cache_key, 1, ex=3600) - def _initialize_vector_database(self) -> None: + def _initialize_vector_database(self): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore request = gpdb_20160503_models.InitVectorDatabaseRequest( @@ -87,7 +87,7 @@ class AnalyticdbVectorOpenAPI: ) self._client.init_vector_database(request) - def _create_namespace_if_not_exists(self) -> None: + def _create_namespace_if_not_exists(self): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException # type: ignore @@ -200,7 +200,7 @@ class AnalyticdbVectorOpenAPI: response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models ids_str = ",".join(f"'{id}'" for id in ids) @@ -216,7 +216,7 @@ class AnalyticdbVectorOpenAPI: ) self._client.delete_collection_data(request) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -305,7 +305,7 @@ class AnalyticdbVectorOpenAPI: documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def delete(self) -> None: + def delete(self): try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index d1de43c5e..12126f32d 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -23,7 +23,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config ANALYTICDB_HOST is required") if not values["port"]: @@ -52,7 +52,7 @@ class AnalyticdbVectorBySql: if not self.pool: self.pool = self._create_connection_pool() - def _initialize(self) -> None: + def _initialize(self): cache_key = f"vector_initialize_{self.config.host}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -85,7 +85,7 @@ class AnalyticdbVectorBySql: conn.commit() self.pool.putconn(conn) - def _initialize_vector_database(self) -> None: + def _initialize_vector_database(self): conn = psycopg2.connect( host=self.config.host, port=self.config.port, @@ -188,7 +188,7 @@ class AnalyticdbVectorBySql: cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,)) return cur.fetchone() is not None - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return with self._get_cursor() as cur: @@ -198,7 +198,7 @@ class AnalyticdbVectorBySql: if "does not exist" not in str(e): raise e - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: try: cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value)) @@ -270,6 +270,6 @@ class AnalyticdbVectorBySql: documents.append(doc) return documents - def delete(self) -> None: + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index d30cf4260..aa980f383 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -36,7 +36,7 @@ class BaiduConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["endpoint"]: raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") if not values["account"]: @@ -66,7 +66,7 @@ class BaiduVector(BaseVector): def get_type(self) -> str: return VectorType.BAIDU - def to_index_struct(self) -> dict: + def to_index_struct(self): return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): @@ -111,13 +111,13 @@ class BaiduVector(BaseVector): return True return False - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return quoted_ids = [f"'{id}'" for id in ids] self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -164,7 +164,7 @@ class BaiduVector(BaseVector): return docs - def delete(self) -> None: + def delete(self): try: self._db.drop_table(table_name=self._collection_name) except ServerError as e: @@ -201,7 +201,7 @@ class BaiduVector(BaseVector): tables = self._db.list_table() return any(table.table_name == self._collection_name for table in tables) - def _create_table(self, dimension: int) -> None: + def _create_table(self, dimension: int): # Try to grab distributed lock and create table lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=60): diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 88da86cf7..e7128b183 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -82,7 +82,7 @@ class ChromaVector(BaseVector): def delete(self): self._client.delete_collection(self._collection_name) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return collection = self._client.get_or_create_collection(self._collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 505cfb4c1..eb4cbd232 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -49,7 +49,7 @@ class ClickzettaConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): """ Validate the configuration values. """ @@ -134,7 +134,7 @@ class ClickzettaConnectionPool: raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") - def _configure_connection(self, connection: "Connection") -> None: + def _configure_connection(self, connection: "Connection"): """Configure connection session settings.""" try: with connection.cursor() as cursor: @@ -221,7 +221,7 @@ class ClickzettaConnectionPool: # No valid connection found, create new one return self._create_connection(config) - def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: + def return_connection(self, config: ClickzettaConfig, connection: "Connection"): """Return a connection to the pool.""" config_key = self._get_config_key(config) @@ -243,7 +243,7 @@ class ClickzettaConnectionPool: with contextlib.suppress(Exception): connection.close() - def _cleanup_expired_connections(self) -> None: + def _cleanup_expired_connections(self): """Clean up expired connections from all pools.""" current_time = time.time() @@ -265,7 +265,7 @@ class ClickzettaConnectionPool: self._pools[config_key] = valid_connections - def _start_cleanup_thread(self) -> None: + def _start_cleanup_thread(self): """Start background thread for connection cleanup.""" def cleanup_worker(): @@ -280,7 +280,7 @@ class ClickzettaConnectionPool: self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) self._cleanup_thread.start() - def shutdown(self) -> None: + def shutdown(self): """Shutdown connection pool and close all connections.""" self._shutdown = True @@ -319,7 +319,7 @@ class ClickzettaVector(BaseVector): """Get a connection from the pool.""" return self._connection_pool.get_connection(self._config) - def _return_connection(self, connection: "Connection") -> None: + def _return_connection(self, connection: "Connection"): """Return a connection to the pool.""" self._connection_pool.return_connection(self._config, connection) @@ -342,7 +342,7 @@ class ClickzettaVector(BaseVector): """Get a connection context manager.""" return self.ConnectionContext(self) - def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict: + def _parse_metadata(self, raw_metadata: str, row_id: str): """ Parse metadata from JSON string with proper error handling and fallback. @@ -723,7 +723,7 @@ class ClickzettaVector(BaseVector): result = cursor.fetchone() return result[0] > 0 if result else False - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): """Delete documents by IDs.""" if not ids: return @@ -736,7 +736,7 @@ class ClickzettaVector(BaseVector): # Execute delete through write queue self._execute_write(self._delete_by_ids_impl, ids) - def _delete_by_ids_impl(self, ids: list[str]) -> None: + def _delete_by_ids_impl(self, ids: list[str]): """Implementation of delete by IDs (executed in write worker thread).""" safe_ids = [self._safe_doc_id(id) for id in ids] @@ -748,7 +748,7 @@ class ClickzettaVector(BaseVector): with connection.cursor() as cursor: cursor.execute(sql, binding_params=safe_ids) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): """Delete documents by metadata field.""" # Check if table exists before attempting delete if not self._table_exists(): @@ -758,7 +758,7 @@ class ClickzettaVector(BaseVector): # Execute delete through write queue self._execute_write(self._delete_by_metadata_field_impl, key, value) - def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: + def _delete_by_metadata_field_impl(self, key: str, value: str): """Implementation of delete by metadata field (executed in write worker thread).""" with self.get_connection_context() as connection: with connection.cursor() as cursor: @@ -1027,7 +1027,7 @@ class ClickzettaVector(BaseVector): return documents - def delete(self) -> None: + def delete(self): """Delete the entire collection.""" with self.get_connection_context() as connection: with connection.cursor() as cursor: diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index 9c34f51c6..6df909ca9 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values.get("connection_string"): raise ValueError("config COUCHBASE_CONNECTION_STRING is required") if not values.get("user"): @@ -234,7 +234,7 @@ class CouchbaseVector(BaseVector): return bool(row["count"] > 0) return False # Return False if no rows are returned - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): query = f""" DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name} WHERE META().id IN $doc_ids; @@ -261,7 +261,7 @@ class CouchbaseVector(BaseVector): # result = self._cluster.query(query, named_parameters={'value':value}) # return [row['id'] for row in result.rows()] - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): query = f""" DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} WHERE metadata.{key} = $value; diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 4e288ccc0..df1c74758 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): use_cloud = values.get("use_cloud", False) cloud_url = values.get("cloud_url") @@ -174,20 +174,20 @@ class ElasticSearchVector(BaseVector): def text_exists(self, id: str) -> bool: return bool(self._client.exists(index=self._collection_name, id=id)) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: self._client.delete(index=self._collection_name, id=id) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) - def delete(self) -> None: + def delete(self): self._client.indices.delete(index=self._collection_name) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index f0d014b1e..107ea75e6 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -33,7 +33,7 @@ class HuaweiCloudVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["hosts"]: raise ValueError("config HOSTS is required") return values @@ -78,20 +78,20 @@ class HuaweiCloudVector(BaseVector): def text_exists(self, id: str) -> bool: return bool(self._client.exists(index=self._collection_name, id=id)) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: self._client.delete(index=self._collection_name, id=id) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) - def delete(self) -> None: + def delete(self): self._client.indices.delete(index=self._collection_name) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index cba10b5aa..5097412c2 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -36,7 +36,7 @@ class LindormVectorStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["hosts"]: raise ValueError("config URL is required") if not values["username"]: @@ -167,7 +167,7 @@ class LindormVectorStore(BaseVector): if ids: self.delete_by_ids(ids) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): """Delete documents by their IDs in batch. Args: @@ -213,7 +213,7 @@ class LindormVectorStore(BaseVector): else: logger.exception("Error deleting document: %s", error) - def delete(self) -> None: + def delete(self): if self._using_ugc: routing_filter_query = { "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}} @@ -372,7 +372,7 @@ class LindormVectorStore(BaseVector): # logger.info(f"create index success: {self._collection_name}") -def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict: +def default_text_mapping(dimension: int, method_name: str, **kwargs: Any): excludes_from_source = kwargs.get("excludes_from_source", False) analyzer = kwargs.get("analyzer", "ik_max_word") text_field = kwargs.get("text_field", Field.CONTENT_KEY.value) @@ -456,7 +456,7 @@ def default_text_search_query( routing: Optional[str] = None, routing_field: Optional[str] = None, **kwargs, -) -> dict: +): query_clause: dict[str, Any] = {} if routing is not None: query_clause = { @@ -513,7 +513,7 @@ def default_vector_search_query( filters: Optional[list[dict]] = None, filter_type: Optional[str] = None, **kwargs, -) -> dict: +): if filters is not None: filter_type = "pre_filter" if filter_type is None else filter_type if not isinstance(filters, list): diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 564f8fc20..1bf8da5da 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -29,7 +29,7 @@ class MatrixoneConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config host is required") if not values["port"]: @@ -128,7 +128,7 @@ class MatrixoneVector(BaseVector): return len(result) > 0 @ensure_client - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): assert self.client is not None if not ids: return @@ -141,7 +141,7 @@ class MatrixoneVector(BaseVector): return [result.id for result in results] @ensure_client - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): assert self.client is not None self.client.delete(filter={key: value}) @@ -207,7 +207,7 @@ class MatrixoneVector(BaseVector): return docs @ensure_client - def delete(self) -> None: + def delete(self): assert self.client is not None self.client.delete() diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 4ad0fada1..2ec48ae36 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -36,7 +36,7 @@ class MilvusConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): """ Validate the configuration values. Raises ValueError if required fields are missing. @@ -79,7 +79,7 @@ class MilvusVector(BaseVector): self._load_collection_fields() self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported - def _load_collection_fields(self, fields: Optional[list[str]] = None) -> None: + def _load_collection_fields(self, fields: Optional[list[str]] = None): if fields is None: # Load collection fields from remote server collection_info = self._client.describe_collection(self._collection_name) @@ -171,7 +171,7 @@ class MilvusVector(BaseVector): if ids: self._client.delete(collection_name=self._collection_name, pks=ids) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): """ Delete documents by their IDs. """ @@ -183,7 +183,7 @@ class MilvusVector(BaseVector): ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, pks=ids) - def delete(self) -> None: + def delete(self): """ Delete the entire collection. """ diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index d048f3b34..b590a4dfe 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -101,7 +101,7 @@ class MyScaleVector(BaseVector): results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") return results.row_count > 0 - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return self._client.command( @@ -114,7 +114,7 @@ class MyScaleVector(BaseVector): ).result_rows return [row[0] for row in rows] - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): self._client.command( f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'" ) @@ -156,7 +156,7 @@ class MyScaleVector(BaseVector): logger.exception("Vector search operation failed") return [] - def delete(self) -> None: + def delete(self): self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}") diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 556d03940..44adf22d0 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -35,7 +35,7 @@ class OceanBaseVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config OCEANBASE_VECTOR_HOST is required") if not values["port"]: @@ -68,7 +68,7 @@ class OceanBaseVector(BaseVector): self._create_collection() self.add_texts(texts, embeddings) - def _create_collection(self) -> None: + def _create_collection(self): lock_name = "vector_indexing_lock_" + self._collection_name with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = "vector_indexing_" + self._collection_name @@ -174,7 +174,7 @@ class OceanBaseVector(BaseVector): cur = self._client.get(table_name=self._collection_name, ids=id) return bool(cur.rowcount != 0) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return self._client.delete(table_name=self._collection_name, ids=ids) @@ -190,7 +190,7 @@ class OceanBaseVector(BaseVector): ) return [row[0] for row in cur] - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) self.delete_by_ids(ids) @@ -278,7 +278,7 @@ class OceanBaseVector(BaseVector): ) return docs - def delete(self) -> None: + def delete(self): self._client.drop_table_if_exist(self._collection_name) diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py index c448210d9..f9dbfbeea 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py @@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config OPENGAUSS_HOST is required") if not values["port"]: @@ -159,7 +159,7 @@ class OpenGauss(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. # Then clicking the retry button triggers a delete operation on an empty list. @@ -168,7 +168,7 @@ class OpenGauss(BaseVector): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) @@ -222,7 +222,7 @@ class OpenGauss(BaseVector): return docs - def delete(self) -> None: + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 917c27eab..3f65a4a27 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -33,7 +33,7 @@ class OpenSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): @@ -128,7 +128,7 @@ class OpenSearchVector(BaseVector): if ids: self.delete_by_ids(ids) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): index_name = self._collection_name.lower() if not self._client.indices.exists(index=index_name): logger.warning("Index %s does not exist", index_name) @@ -159,7 +159,7 @@ class OpenSearchVector(BaseVector): else: logger.exception("Error deleting document: %s", error) - def delete(self) -> None: + def delete(self): self._client.indices.delete(index=self._collection_name.lower()) def text_exists(self, id: str) -> bool: diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 1b99f649b..23997d3d2 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -33,7 +33,7 @@ class OracleVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["user"]: raise ValueError("config ORACLE_USER is required") if not values["password"]: @@ -206,7 +206,7 @@ class OracleVector(BaseVector): conn.close() return docs - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return with self._get_connection() as conn: @@ -216,7 +216,7 @@ class OracleVector(BaseVector): conn.commit() conn.close() - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): with self._get_connection() as conn: with conn.cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,)) @@ -336,7 +336,7 @@ class OracleVector(BaseVector): else: return [Document(page_content="", metadata={})] - def delete(self) -> None: + def delete(self): with self._get_connection() as conn: with conn.cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 99cd4a22c..b986c79e3 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") if not values["port"]: @@ -150,7 +150,7 @@ class PGVectoRS(BaseVector): session.execute(select_statement, {"ids": ids}) session.commit() - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): with Session(self._client) as session: select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " @@ -164,7 +164,7 @@ class PGVectoRS(BaseVector): session.execute(select_statement, {"ids": ids}) session.commit() - def delete(self) -> None: + def delete(self): with Session(self._client) as session: session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) session.commit() diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 13be18f92..445a0a7f8 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") if not values["port"]: @@ -146,7 +146,7 @@ class PGVector(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. # Then clicking the retry button triggers a delete operation on an empty list. @@ -162,7 +162,7 @@ class PGVector(BaseVector): except Exception as e: raise e - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) @@ -242,7 +242,7 @@ class PGVector(BaseVector): return docs - def delete(self) -> None: + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py index c33e344bf..86b6ace3f 100644 --- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py +++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py @@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config VASTBASE_HOST is required") if not values["port"]: @@ -133,7 +133,7 @@ class VastbaseVector(BaseVector): docs.append(Document(page_content=record[1], metadata=record[0])) return docs - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios # Scenario 1: extract a document fails, resulting in a table not being created. # Then clicking the retry button triggers a delete operation on an empty list. @@ -142,7 +142,7 @@ class VastbaseVector(BaseVector): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) @@ -199,7 +199,7 @@ class VastbaseVector(BaseVector): return docs - def delete(self) -> None: + def delete(self): with self._get_cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index e55c06e66..12d97c500 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -81,7 +81,7 @@ class QdrantVector(BaseVector): def get_type(self) -> str: return VectorType.QDRANT - def to_index_struct(self) -> dict: + def to_index_struct(self): return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): @@ -292,7 +292,7 @@ class QdrantVector(BaseVector): else: raise e - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index a200bacfb..9d3dc7c62 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -35,7 +35,7 @@ class RelytConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config RELYT_HOST is required") if not values["port"]: @@ -64,7 +64,7 @@ class RelytVector(BaseVector): def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) self.add_texts(texts, embeddings) @@ -196,7 +196,7 @@ class RelytVector(BaseVector): if ids: self.delete_by_uuids(ids) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): with Session(self.client) as session: ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( @@ -207,7 +207,7 @@ class RelytVector(BaseVector): ids = [item[0] for item in result] self.delete_by_uuids(ids) - def delete(self) -> None: + def delete(self): with Session(self.client) as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) session.commit() diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 9c5535152..27685b7dd 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["access_key_id"]: raise ValueError("config ACCESS_KEY_ID is required") if not values["access_key_secret"]: @@ -112,7 +112,7 @@ class TableStoreVector(BaseVector): return return_row is not None - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return for id in ids: @@ -121,7 +121,7 @@ class TableStoreVector(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): return self._search_by_metadata(key, value) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) self.delete_by_ids(ids) @@ -143,7 +143,7 @@ class TableStoreVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._search_by_full_text(query, filtered_list, top_k, score_threshold) - def delete(self) -> None: + def delete(self): self._delete_table_if_exist() def _create_collection(self, dimension: int): @@ -158,7 +158,7 @@ class TableStoreVector(BaseVector): self._create_search_index_if_not_exist(dimension) redis_client.set(collection_exist_cache_key, 1, ex=3600) - def _create_table_if_not_exist(self) -> None: + def _create_table_if_not_exist(self): table_list = self._tablestore_client.list_table() if self._table_name in table_list: logger.info("Tablestore system table[%s] already exists", self._table_name) @@ -171,7 +171,7 @@ class TableStoreVector(BaseVector): self._tablestore_client.create_table(table_meta, table_options, reserved_throughput) logger.info("Tablestore create table[%s] successfully.", self._table_name) - def _create_search_index_if_not_exist(self, dimension: int) -> None: + def _create_search_index_if_not_exist(self, dimension: int): search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name) assert isinstance(search_index_list, Iterable) if self._index_name in [t[1] for t in search_index_list]: @@ -225,11 +225,11 @@ class TableStoreVector(BaseVector): self._tablestore_client.delete_table(self._table_name) logger.info("Tablestore delete system table[%s] successfully.", self._index_name) - def _delete_search_index(self) -> None: + def _delete_search_index(self): self._tablestore_client.delete_search_index(self._table_name, self._index_name) logger.info("Tablestore delete index[%s] successfully.", self._index_name) - def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None: + def _write_row(self, primary_key: str, attributes: dict[str, Any]): pk = [("id", primary_key)] tags = [] @@ -248,7 +248,7 @@ class TableStoreVector(BaseVector): row = tablestore.Row(pk, attribute_columns) self._tablestore_client.put_row(self._table_name, row) - def _delete_row(self, id: str) -> None: + def _delete_row(self, id: str): primary_key = [("id", id)] row = tablestore.Row(primary_key) self._tablestore_client.delete_row(self._table_name, row, None) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3df35d081..4af34bbb2 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -82,7 +82,7 @@ class TencentVector(BaseVector): def get_type(self) -> str: return VectorType.TENCENT - def to_index_struct(self) -> dict: + def to_index_struct(self): return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def _has_collection(self) -> bool: @@ -92,7 +92,7 @@ class TencentVector(BaseVector): ) ) - def _create_collection(self, dimension: int) -> None: + def _create_collection(self, dimension: int): self._dimension = dimension lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -205,7 +205,7 @@ class TencentVector(BaseVector): return True return False - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): if not ids: return @@ -222,7 +222,7 @@ class TencentVector(BaseVector): database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids ) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): self._client.delete( database_name=self._client_config.database, collection_name=self.collection_name, @@ -299,7 +299,7 @@ class TencentVector(BaseVector): docs.append(doc) return docs - def delete(self) -> None: + def delete(self): if self._has_collection(): self._client.drop_collection( database_name=self._client_config.database, collection_name=self.collection_name diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index be24f5a56..705558145 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -90,7 +90,7 @@ class TidbOnQdrantVector(BaseVector): def get_type(self) -> str: return VectorType.TIDB_ON_QDRANT - def to_index_struct(self) -> dict: + def to_index_struct(self): return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): @@ -284,7 +284,7 @@ class TidbOnQdrantVector(BaseVector): else: raise e - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index e5492cb7f..6efc04aa2 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") if not values["port"]: @@ -144,7 +144,7 @@ class TiDBVector(BaseVector): result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): with Session(self._engine) as session: ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( @@ -179,7 +179,7 @@ class TiDBVector(BaseVector): else: return None - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self._delete_by_ids(ids) @@ -237,7 +237,7 @@ class TiDBVector(BaseVector): # tidb doesn't support bm25 search return [] - def delete(self) -> None: + def delete(self): with Session(self._engine) as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) session.commit() diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py index 9e99f14dc..289d97185 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["url"]: raise ValueError("Upstash URL is required") if not values["token"]: @@ -60,7 +60,7 @@ class UpstashVector(BaseVector): response = self.get_ids_by_metadata_field("doc_id", id) return len(response) > 0 - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): item_ids = [] for doc_id in ids: ids = self.get_ids_by_metadata_field("doc_id", doc_id) @@ -68,7 +68,7 @@ class UpstashVector(BaseVector): item_ids += ids self._delete_by_ids(ids=item_ids) - def _delete_by_ids(self, ids: list[str]) -> None: + def _delete_by_ids(self, ids: list[str]): if ids: self.index.delete(ids=ids) @@ -81,7 +81,7 @@ class UpstashVector(BaseVector): ) return [result.id for result in query_result] - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: self._delete_by_ids(ids) @@ -117,7 +117,7 @@ class UpstashVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def delete(self) -> None: + def delete(self): self.index.reset() def get_type(self) -> str: diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index edfce2edd..469978224 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -27,14 +27,14 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): raise NotImplementedError def get_ids_by_metadata_field(self, key: str, value: str): raise NotImplementedError @abstractmethod - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): raise NotImplementedError @abstractmethod @@ -46,7 +46,7 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def delete(self) -> None: + def delete(self): raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 661a8f37a..b2cc51d03 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -26,7 +26,7 @@ class AbstractVectorFactory(ABC): raise NotImplementedError @staticmethod - def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: + def gen_index_struct_dict(vector_type: VectorType, collection_name: str): index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict @@ -207,10 +207,10 @@ class Vector: def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): self._vector_processor.delete_by_ids(ids) - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): self._vector_processor.delete_by_metadata_field(key, value) def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: @@ -220,7 +220,7 @@ class Vector: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) - def delete(self) -> None: + def delete(self): self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 33267741c..d1bdd3bae 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -144,7 +144,7 @@ class VikingDBVector(BaseVector): return True return False - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): self._client.get_collection(self._collection_name).delete_data(ids) def get_ids_by_metadata_field(self, key: str, value: str): @@ -168,7 +168,7 @@ class VikingDBVector(BaseVector): ids.append(result.id) return ids - def delete_by_metadata_field(self, key: str, value: str) -> None: + def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) self.delete_by_ids(ids) @@ -202,7 +202,7 @@ class VikingDBVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def delete(self) -> None: + def delete(self): if self._has_index(): self._client.drop_index(self._collection_name, self._index_name) if self._has_collection(): diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index bc237b591..43dde37c7 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -24,7 +24,7 @@ class WeaviateConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values @@ -75,7 +75,7 @@ class WeaviateVector(BaseVector): dataset_id = dataset.id return Dataset.gen_collection_name_by_id(dataset_id) - def to_index_struct(self) -> dict: + def to_index_struct(self): return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): @@ -164,7 +164,7 @@ class WeaviateVector(BaseVector): return True - def delete_by_ids(self, ids: list[str]) -> None: + def delete_by_ids(self, ids: list[str]): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): @@ -256,7 +256,7 @@ class WeaviateVector(BaseVector): docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs - def _default_schema(self, index_name: str) -> dict: + def _default_schema(self, index_name: str): return { "class": index_name, "properties": [ @@ -267,7 +267,7 @@ class WeaviateVector(BaseVector): ], } - def _json_serializable(self, value: Any) -> Any: + def _json_serializable(self, value: Any): if isinstance(value, datetime.datetime): return value.isoformat() return value diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 717cfe8f5..63c6db8d0 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -32,11 +32,11 @@ class DatasetDocumentStore: } @property - def dataset_id(self) -> Any: + def dataset_id(self): return self._dataset.id @property - def user_id(self) -> Any: + def user_id(self): return self._user_id @property @@ -59,7 +59,7 @@ class DatasetDocumentStore: return output - def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: + def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False): max_position = ( db.session.query(func.max(DocumentSegment.position)) .where(DocumentSegment.document_id == self._document_id) @@ -195,7 +195,7 @@ class DatasetDocumentStore: }, ) - def delete_document(self, doc_id: str, raise_error: bool = True) -> None: + def delete_document(self, doc_id: str, raise_error: bool = True): document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -207,7 +207,7 @@ class DatasetDocumentStore: db.session.delete(document_segment) db.session.commit() - def set_document_hash(self, doc_id: str, doc_hash: str) -> None: + def set_document_hash(self, doc_id: str, doc_hash: str): """Set the hash for a given doc_id.""" document_segment = self.get_document_segment(doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index e27c1f059..43be9cde6 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None: + def __init__(self, model_instance: ModelInstance, user: Optional[str] = None): self._model_instance = model_instance self._user = user diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 1593ad147..52d64f591 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -18,7 +18,7 @@ class NotionInfo(BaseModel): tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, **data) -> None: + def __init__(self, **data): super().__init__(**data) @@ -49,5 +49,5 @@ class ExtractSetting(BaseModel): document_model: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, **data) -> None: + def __init__(self, **data): super().__init__(**data) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index fd60af0f1..e1ba6ef24 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -122,7 +122,7 @@ class FirecrawlApp: return response return response - def _handle_error(self, response, action) -> None: + def _handle_error(self, response, action): error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return] diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 3d2fb55d9..17f7d8661 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -29,7 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1 """ import chardet - def read_and_detect(file_path: str) -> list[dict]: + def read_and_detect(file_path: str): with open(file_path, "rb") as f: # Read only a sample of the file for encoding detection # This prevents timeout on large files while still providing accurate encoding detection diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index da03fc67a..c59a70ea5 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: Optional[dict | Any] = None) -> dict: + def crawl_url(self, url, options: Optional[dict | Any] = None): options = options or {} spider_options = { "max_depth": 1, @@ -41,7 +41,7 @@ class WaterCrawlProvider: return {"status": "active", "job_id": result.get("uuid")} - def get_crawl_status(self, crawl_request_id) -> dict: + def get_crawl_status(self, crawl_request_id): response = self.client.get_crawl_request(crawl_request_id) data = [] if response["status"] in ["new", "running"]: @@ -82,11 +82,11 @@ class WaterCrawlProvider: return None - def scrape_url(self, url: str) -> dict: + def scrape_url(self, url: str): response = self.client.scrape_url(url=url, sync=True, prefetched=True) return self._structure_data(response) - def _structure_data(self, result_object: dict) -> dict: + def _structure_data(self, result_object: dict): if isinstance(result_object.get("result", {}), str): raise ValueError("Invalid result object. Expected a dictionary.") diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f3b162e3d..f25f92cf8 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -56,7 +56,7 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") - def __del__(self) -> None: + def __del__(self): if hasattr(self, "temp_file"): self.temp_file.close() diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 693535413..7a6ebd1f3 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -6,7 +6,7 @@ from core.rag.rerank.rerank_base import BaseRerankRunner class RerankModelRunner(BaseRerankRunner): - def __init__(self, rerank_model_instance: ModelInstance) -> None: + def __init__(self, rerank_model_instance: ModelInstance): self.rerank_model_instance = rerank_model_instance def run( diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 80de746e2..ab49e43b7 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -14,7 +14,7 @@ from core.rag.rerank.rerank_base import BaseRerankRunner class WeightRerankRunner(BaseRerankRunner): - def __init__(self, tenant_id: str, weights: Weights) -> None: + def __init__(self, tenant_id: str, weights: Weights): self.tenant_id = tenant_id self.weights = weights diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 2e63ecfc5..93bad23f2 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -507,7 +507,7 @@ class DatasetRetrieval: def _on_retrieval_end( self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None - ) -> None: + ): """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: @@ -560,7 +560,7 @@ class DatasetRetrieval: ) ) - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None: + def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): """ Handle query. """ diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 1b60fb778..c5b6ac460 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -47,7 +47,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x], keep_separator: bool = False, add_start_index: bool = False, - ) -> None: + ): """Create a new TextSplitter. Args: @@ -201,7 +201,7 @@ class TokenTextSplitter(TextSplitter): allowed_special: Union[Literal["all"], Set[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, - ) -> None: + ): """Create a new TextSplitter.""" super().__init__(**kwargs) try: @@ -251,7 +251,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): separators: Optional[list[str]] = None, keep_separator: bool = True, **kwargs: Any, - ) -> None: + ): """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 384904458..d6f40491b 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -93,7 +93,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): self._triggered_from, ) - def save(self, execution: WorkflowExecution) -> None: + def save(self, execution: WorkflowExecution): """ Save or update a WorkflowExecution instance asynchronously using Celery. diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 1c4e6dfe8..b36252dba 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -106,7 +106,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self._triggered_from, ) - def save(self, execution: WorkflowNodeExecution) -> None: + def save(self, execution: WorkflowNodeExecution): """ Save or update a WorkflowNodeExecution instance to cache and asynchronously to database. diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 74a49842f..46b028b21 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -176,7 +176,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): return db_model - def save(self, execution: WorkflowExecution) -> None: + def save(self, execution: WorkflowExecution): """ Save or update a WorkflowExecution domain entity to the database. diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 85754be14..8702af9f8 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -194,9 +194,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) """Check if the exception is a duplicate key constraint violation.""" return isinstance(exception, IntegrityError) and isinstance(exception.orig, psycopg2.errors.UniqueViolation) - def _regenerate_id_on_duplicate( - self, execution: WorkflowNodeExecution, db_model: WorkflowNodeExecutionModel - ) -> None: + def _regenerate_id_on_duplicate(self, execution: WorkflowNodeExecution, db_model: WorkflowNodeExecutionModel): """Regenerate UUID v7 for both domain and database models when duplicate key detected.""" new_id = str(uuidv7()) logger.warning( @@ -205,7 +203,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) db_model.id = new_id execution.id = new_id - def save(self, execution: WorkflowNodeExecution) -> None: + def save(self, execution: WorkflowNodeExecution): """ Save or update a NodeExecution domain entity to the database. @@ -254,7 +252,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) logger.exception("Failed to save workflow node execution after all retries") raise - def _persist_to_database(self, db_model: WorkflowNodeExecutionModel) -> None: + def _persist_to_database(self, db_model: WorkflowNodeExecutionModel): """ Persist the database model to the database. diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index d6961cdaa..5a2b80393 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -20,7 +20,7 @@ class Tool(ABC): The base class of a tool """ - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: + def __init__(self, entity: ToolEntity, runtime: ToolRuntime): self.entity = entity self.runtime = runtime diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index d1d7976cc..49cbf7037 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -12,7 +12,7 @@ from core.tools.errors import ToolProviderCredentialValidationError class ToolProviderController(ABC): - def __init__(self, entity: ToolProviderEntity) -> None: + def __init__(self, entity: ToolProviderEntity): self.entity = entity def get_credentials_schema(self) -> list[ProviderConfig]: @@ -41,7 +41,7 @@ class ToolProviderController(ABC): """ return ToolProviderType.BUILT_IN - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: + def validate_credentials_format(self, credentials: dict[str, Any]): """ validate the format of the credentials of the provider and set the default value if needed diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 375a32f39..68bfe5b4a 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -24,7 +24,7 @@ from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): tools: list[BuiltinTool] - def __init__(self, **data: Any) -> None: + def __init__(self, **data: Any): self.tools = [] # load provider yaml @@ -197,7 +197,7 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.entity.identity.tags or [] - def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ validate the credentials of the provider @@ -211,7 +211,7 @@ class BuiltinToolProviderController(ToolProviderController): self._validate_credentials(user_id, credentials) @abstractmethod - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ validate the credentials of the provider diff --git a/api/core/tools/builtin_tool/providers/audio/audio.py b/api/core/tools/builtin_tool/providers/audio/audio.py index d7d71161f..abf23559e 100644 --- a/api/core/tools/builtin_tool/providers/audio/audio.py +++ b/api/core/tools/builtin_tool/providers/audio/audio.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class AudioToolProvider(BuiltinToolProviderController): - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/code/code.py b/api/core/tools/builtin_tool/providers/code/code.py index 18b7cd4c9..3e02a64e8 100644 --- a/api/core/tools/builtin_tool/providers/code/code.py +++ b/api/core/tools/builtin_tool/providers/code/code.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class CodeToolProvider(BuiltinToolProviderController): - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/time/time.py b/api/core/tools/builtin_tool/providers/time/time.py index 323a7c41b..c8f33ec56 100644 --- a/api/core/tools/builtin_tool/providers/time/time.py +++ b/api/core/tools/builtin_tool/providers/time/time.py @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class WikiPediaProvider(BuiltinToolProviderController): - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): pass diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py index 52c8370e0..7d8942d42 100644 --- a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py +++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py @@ -4,7 +4,7 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController class WebscraperProvider(BuiltinToolProviderController): - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ Validate credentials """ diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index e3dfa089d..5790aea2b 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -24,7 +24,7 @@ class ApiToolProviderController(ToolProviderController): tenant_id: str tools: list[ApiTool] = Field(default_factory=list) - def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None: + def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str): super().__init__(entity) self.provider_id = provider_id self.tenant_id = tenant_id diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 97342640f..190af999b 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -302,7 +302,7 @@ class ApiTool(Tool): def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 - ) -> Any: + ): if max_recursive <= 0: raise Exception("Max recursion depth reached") for option in any_of or []: @@ -337,7 +337,7 @@ class ApiTool(Tool): # If no option succeeded, you might want to return the value as is or raise an error return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf") - def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any: + def _convert_body_property_type(self, property: dict[str, Any], value: Any): try: if "type" in property: if property["type"] == "integer" or property["type"] == "int": diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 48015c04e..187406fc2 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -49,7 +49,7 @@ class ToolProviderApiEntity(BaseModel): def convert_none_to_empty_list(cls, v): return v if v is not None else [] - def to_dict(self) -> dict: + def to_dict(self): # ------------- # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) @@ -84,7 +84,7 @@ class ToolProviderApiEntity(BaseModel): **optional_fields, } - def optional_field(self, key: str, value: Any) -> dict: + def optional_field(self, key: str, value: Any): """Return dict with key-value if value is truthy, empty dict otherwise.""" return {key: value} if value else {} diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 924e6fc0c..aadbbeb84 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -19,5 +19,5 @@ class I18nObject(BaseModel): self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US - def to_dict(self) -> dict: + def to_dict(self): return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index df599a09a..66304b30a 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -150,7 +150,7 @@ class ToolInvokeMessage(BaseModel): @model_validator(mode="before") @classmethod - def transform_variable_value(cls, values) -> Any: + def transform_variable_value(cls, values): """ Only basic types and lists are allowed. """ @@ -428,7 +428,7 @@ class ToolInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self) -> dict: + def to_dict(self): return { "time_cost": self.time_cost, "error": self.error, diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 24ee981a1..fa99cccb8 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -28,7 +28,7 @@ class MCPToolProviderController(ToolProviderController): headers: Optional[dict[str, str]] = None, timeout: Optional[float] = None, sse_read_timeout: Optional[float] = None, - ) -> None: + ): super().__init__(entity) self.entity: ToolProviderEntityWithPlugin = entity self.tenant_id = tenant_id @@ -99,7 +99,7 @@ class MCPToolProviderController(ToolProviderController): sse_read_timeout=db_provider.sse_read_timeout, ) - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ validate the credentials of the provider """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 26789b23c..6810ac683 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -23,7 +23,7 @@ class MCPTool(Tool): headers: Optional[dict[str, str]] = None, timeout: Optional[float] = None, sse_read_timeout: Optional[float] = None, - ) -> None: + ): super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index 494b8e209..3fbbd4c9e 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -16,7 +16,7 @@ class PluginToolProviderController(BuiltinToolProviderController): def __init__( self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str - ) -> None: + ): self.entity = entity self.tenant_id = tenant_id self.plugin_id = plugin_id @@ -31,7 +31,7 @@ class PluginToolProviderController(BuiltinToolProviderController): """ return ToolProviderType.PLUGIN - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): """ validate the credentials of the provider """ diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index db38c10e8..e649caec1 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -11,7 +11,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class PluginTool(Tool): def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str - ) -> None: + ): super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9897045d9..834f58be6 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -778,7 +778,7 @@ class ToolManager: return controller @classmethod - def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: + def user_get_api_provider(cls, provider: str, tenant_id: str): """ get api provider """ @@ -873,7 +873,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str): try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -890,7 +890,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str): try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 3a9391dbb..3ac487a47 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -24,7 +24,7 @@ class ToolParameterConfigurationManager: def __init__( self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str - ) -> None: + ): self.tenant_id = tenant_id self.tool_runtime = tool_runtime self.provider_name = provider_name diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index d58807e29..d5803e33e 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -20,7 +20,7 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas class DatasetRetrieverTool(Tool): - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool): super().__init__(entity, runtime) self.retrieval_tool = retrieval_tool diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index d771293e1..5820be0ff 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -17,11 +17,11 @@ class ProviderConfigCache(Protocol): """Get cached provider configuration""" ... - def set(self, config: dict[str, Any]) -> None: + def set(self, config: dict[str, Any]): """Cache provider configuration""" ... - def delete(self) -> None: + def delete(self): """Delete cached provider configuration""" ... diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 78f1f339f..cae21633f 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -242,7 +242,7 @@ class ApiBasedToolSchemaParser: return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @staticmethod - def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict: + def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None): warning = warning or {} """ parse swagger to openapi diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index ee7ca11e0..8a0a91a50 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -8,7 +8,7 @@ from yaml import YAMLError logger = logging.getLogger(__name__) -def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: +def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}): """ Safe loading a YAML file :param file_path: the path of the YAML file diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9bcc63952..c9d62388f 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -223,7 +223,7 @@ class WorkflowTool(Tool): return result, files - def _update_file_mapping(self, file_dict: dict) -> dict: + def _update_file_mapping(self, file_dict: dict): transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: file_dict["tool_file_id"] = file_dict.get("related_id") diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 9e7616874..cfef19363 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -51,7 +51,7 @@ class Segment(BaseModel): """ return sys.getsizeof(self.value) - def to_object(self) -> Any: + def to_object(self): return self.value diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 55f8ae3c7..a2e12e742 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -159,7 +159,7 @@ class SegmentType(StrEnum): raise AssertionError("this statement should be unreachable.") @staticmethod - def cast_value(value: Any, type_: "SegmentType") -> Any: + def cast_value(value: Any, type_: "SegmentType"): # Cast Python's `bool` type to `int` when the runtime type requires # an integer or number. # diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index 7ebd29f86..8e738f8fd 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -14,7 +14,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[ return selectors -def segment_orjson_default(o: Any) -> Any: +def segment_orjson_default(o: Any): """Default function for orjson serialization of Segment types""" if isinstance(o, ArrayFileSegment): return [v.model_dump() for v in o.value] diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 83086d1af..5f1372c65 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -5,7 +5,7 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: + def on_event(self, event: GraphEngineEvent): """ Published event """ diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index 12b5203ca..ec62be605 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -36,10 +36,10 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): - def __init__(self) -> None: + def __init__(self): self.current_node_id: Optional[str] = None - def on_event(self, event: GraphEngineEvent) -> None: + def on_event(self, event: GraphEngineEvent): if isinstance(event, GraphRunStartedEvent): self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): @@ -75,7 +75,7 @@ class WorkflowLoggingCallback(WorkflowCallback): else: self.print_text(f"\n[{event.__class__.__name__}]", color="blue") - def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent): """ Workflow node execute started """ @@ -84,7 +84,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text(f"Node Title: {event.node_data.title}", color="yellow") self.print_text(f"Type: {event.node_type.value}", color="yellow") - def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent): """ Workflow node execute succeeded """ @@ -115,7 +115,7 @@ class WorkflowLoggingCallback(WorkflowCallback): color="green", ) - def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent): """ Workflow node execute failed """ @@ -143,7 +143,7 @@ class WorkflowLoggingCallback(WorkflowCallback): color="red", ) - def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent): """ Publish text chunk """ @@ -161,7 +161,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text(event.chunk_content, color="pink", end="") - def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent): """ Publish parallel started """ @@ -173,9 +173,7 @@ class WorkflowLoggingCallback(WorkflowCallback): if event.in_loop_id: self.print_text(f"Loop ID: {event.in_loop_id}", color="blue") - def on_workflow_parallel_completed( - self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent - ) -> None: + def on_workflow_parallel_completed(self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): """ Publish parallel completed """ @@ -200,14 +198,14 @@ class WorkflowLoggingCallback(WorkflowCallback): if isinstance(event, ParallelBranchRunFailedEvent): self.print_text(f"Error: {event.error}", color=color) - def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: + def on_workflow_iteration_started(self, event: IterationRunStartedEvent): """ Publish iteration started """ self.print_text("\n[IterationRunStartedEvent]", color="blue") self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") - def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: + def on_workflow_iteration_next(self, event: IterationRunNextEvent): """ Publish iteration next """ @@ -215,7 +213,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") self.print_text(f"Iteration Index: {event.index}", color="blue") - def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent): """ Publish iteration completed """ @@ -227,14 +225,14 @@ class WorkflowLoggingCallback(WorkflowCallback): ) self.print_text(f"Node ID: {event.iteration_id}", color="blue") - def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None: + def on_workflow_loop_started(self, event: LoopRunStartedEvent): """ Publish loop started """ self.print_text("\n[LoopRunStartedEvent]", color="blue") self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") - def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None: + def on_workflow_loop_next(self, event: LoopRunNextEvent): """ Publish loop next """ @@ -242,7 +240,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") self.print_text(f"Loop Index: {event.index}", color="blue") - def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None: + def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent): """ Publish loop completed """ @@ -252,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback): ) self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue") - def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n"): """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(f"{text_to_print}", end=end) diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index 84e99bb58..fd78248c1 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol): """ @abc.abstractmethod - def update(self, conversation_id: str, variable: "Variable") -> None: + def update(self, conversation_id: str, variable: "Variable"): """ Updates the value of the specified conversation variable in the underlying storage. diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index fb0794844..a2c13fcbf 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -47,7 +47,7 @@ class VariablePool(BaseModel): default_factory=list, ) - def model_post_init(self, context: Any, /) -> None: + def model_post_init(self, context: Any, /): # Create a mapping from field names to SystemVariableKey enum values self._add_system_variables(self.system_variables) # Add environment variables to the variable pool @@ -57,7 +57,7 @@ class VariablePool(BaseModel): for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) - def add(self, selector: Sequence[str], value: Any, /) -> None: + def add(self, selector: Sequence[str], value: Any, /): """ Add a variable to the variable pool. @@ -161,11 +161,11 @@ class VariablePool(BaseModel): # Return result as Segment return result if isinstance(result, Segment) else variable_factory.build_segment(result) - def _extract_value(self, obj: Any) -> Any: + def _extract_value(self, obj: Any): """Extract the actual value from an ObjectSegment.""" return obj.value if isinstance(obj, ObjectSegment) else obj - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any: + def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str): """Get a nested attribute from a dictionary-like object.""" if not isinstance(obj, dict): return None diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index 09a408f4d..ff72d7cbf 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -112,7 +112,7 @@ class WorkflowNodeExecution(BaseModel): process_data: Optional[Mapping[str, Any]] = None, outputs: Optional[Mapping[str, Any]] = None, metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, - ) -> None: + ): """ Update the model from mappings. diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 49984806c..d8d1825d9 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -205,9 +205,7 @@ class Graph(BaseModel): return graph @classmethod - def _recursively_add_node_ids( - cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str - ) -> None: + def _recursively_add_node_ids(cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str): """ Recursively add node ids @@ -225,7 +223,7 @@ class Graph(BaseModel): ) @classmethod - def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]): """ Check whether it is connected to the previous node """ @@ -256,7 +254,7 @@ class Graph(BaseModel): parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], parent_parallel: Optional[GraphParallel] = None, - ) -> None: + ): """ Recursively add parallel ids @@ -461,7 +459,7 @@ class Graph(BaseModel): level_limit: int, parent_parallel_id: str, current_level: int = 1, - ) -> None: + ): """ Check if it exceeds N layers of parallel """ @@ -488,7 +486,7 @@ class Graph(BaseModel): edge_mapping: dict[str, list[GraphEdge]], merge_node_id: str, start_node_id: str, - ) -> None: + ): """ Recursively add node ids @@ -614,7 +612,7 @@ class Graph(BaseModel): @classmethod def _recursively_fetch_routes( cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] - ) -> None: + ): """ Recursively fetch route """ diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index a4ddfafab..54440df72 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -47,7 +47,7 @@ class RouteNodeState(BaseModel): index: int = 1 - def set_finished(self, run_result: NodeRunResult) -> None: + def set_finished(self, run_result: NodeRunResult): """ Node finished @@ -94,7 +94,7 @@ class RuntimeRouteState(BaseModel): self.node_state_mapping[state.id] = state return state - def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + def add_route(self, source_node_state_id: str, target_node_state_id: str): """ Add route to the graph state diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 188d0c475..9b0b187a7 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -66,7 +66,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor): initializer=None, initargs=(), max_submit_count=dify_config.MAX_SUBMIT_COUNT, - ) -> None: + ): super().__init__(max_workers, thread_name_prefix, initializer, initargs) self.max_submit_count = max_submit_count self.submit_count = 0 @@ -80,7 +80,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor): def task_done_callback(self, future): self.submit_count -= 1 - def check_is_full(self) -> None: + def check_is_full(self): if self.submit_count > self.max_submit_count: raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") @@ -104,7 +104,7 @@ class GraphEngine: max_execution_steps: int, max_execution_time: int, thread_pool_id: Optional[str] = None, - ) -> None: + ): thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 @@ -537,7 +537,7 @@ class GraphEngine: parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, handle_exceptions: list[str] = [], - ) -> None: + ): """ Run parallel nodes """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 9e5d5e62b..2e5912652 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -66,7 +66,7 @@ class AgentNode(BaseNode): _node_type = NodeType.AGENT _node_data: AgentNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = AgentNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 84bbabca7..116250c5c 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -22,7 +22,7 @@ class AnswerNode(BaseNode): _node_data: AnswerNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = AnswerNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1d9c3e9b9..216fe9b67 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -134,7 +134,7 @@ class AnswerStreamGeneratorRouter: node_id_config_mapping: dict[str, dict], reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] answer_dependencies: dict[str, list[str]], - ) -> None: + ): """ Recursive fetch answer dependencies :param current_node_id: current node id diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index a30014299..2b1070f5e 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool): super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} @@ -66,7 +66,7 @@ class AnswerStreamProcessor(StreamProcessor): else: yield event - def reset(self) -> None: + def reset(self): self.route_position = {} for answer_node_id, _ in self.generate_routes.answer_generate_route.items(): self.route_position[answer_node_id] = 0 diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 7e84557a2..9e8e1787e 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool): self.graph = graph self.variable_pool = variable_pool self.rest_node_ids = graph.node_ids.copy() @@ -20,7 +20,7 @@ class StreamProcessor(ABC): def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError - def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent): finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return @@ -89,7 +89,7 @@ class StreamProcessor(ABC): node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) return node_ids - def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]): """ remove target node ids until merge """ diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index dcfed5eed..708da2117 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -28,7 +28,7 @@ class DefaultValue(BaseModel): key: str @staticmethod - def _parse_json(value: str) -> Any: + def _parse_json(value: str): """Unified JSON parsing handler""" try: return json.loads(value) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index be4f79af1..3aee9b2cc 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -28,7 +28,7 @@ class BaseNode: graph_runtime_state: "GraphRuntimeState", previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, - ) -> None: + ): self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -51,7 +51,7 @@ class BaseNode: self.node_id = node_id @abstractmethod - def init_node_data(self, data: Mapping[str, Any]) -> None: ... + def init_node_data(self, data: Mapping[str, Any]): ... @abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: @@ -141,7 +141,7 @@ class BaseNode: return {} @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): return {} @property diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 17bd841fc..d32d86865 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -28,7 +28,7 @@ class CodeNode(BaseNode): _node_data: CodeNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = CodeNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -50,7 +50,7 @@ class CodeNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): """ Get default config of node. :param filters: filter by node config parameters. diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 125b84501..7848bab44 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -47,7 +47,7 @@ class DocumentExtractorNode(BaseNode): _node_data: DocumentExtractorNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = DocumentExtractorNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index f86f2e812..0aff039e9 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -14,7 +14,7 @@ class EndNode(BaseNode): _node_data: EndNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = EndNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index b3678a82b..495ed6ea2 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -121,7 +121,7 @@ class EndStreamGeneratorRouter: node_id_config_mapping: dict[str, dict], reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] end_dependencies: dict[str, list[str]], - ) -> None: + ): """ Recursive fetch end dependencies :param current_node_id: current node id diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index a6fb2ffc1..7e426fee7 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class EndStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool): super().__init__(graph, variable_pool) self.end_stream_param = graph.end_stream_param self.route_position = {} @@ -76,7 +76,7 @@ class EndStreamProcessor(StreamProcessor): else: yield event - def reset(self) -> None: + def reset(self): self.route_position = {} for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): self.route_position[end_node_id] = 0 diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index bc1d5c9b8..bb3c453d9 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -38,7 +38,7 @@ class HttpRequestNode(BaseNode): _node_data: HttpRequestNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = HttpRequestNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -60,7 +60,7 @@ class HttpRequestNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: + def get_default_config(cls, filters: Optional[dict[str, Any]] = None): return { "type": "http-request", "config": { diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index c2bed870b..82dba59cb 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -19,7 +19,7 @@ class IfElseNode(BaseNode): _node_data: IfElseNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = IfElseNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 9deac1748..9037677df 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -67,7 +67,7 @@ class IterationNode(BaseNode): _node_data: IterationNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -89,7 +89,7 @@ class IterationNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): return { "type": "iteration", "config": { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index b82c29291..8c4794cf3 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -18,7 +18,7 @@ class IterationStartNode(BaseNode): _node_data: IterationStartNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationStartNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 949a05d05..d357fea7d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -105,7 +105,7 @@ class KnowledgeRetrievalNode(BaseNode): thread_pool_id: Optional[str] = None, *, llm_file_saver: LLMFileSaver | None = None, - ) -> None: + ): super().__init__( id=id, config=config, @@ -125,7 +125,7 @@ class KnowledgeRetrievalNode(BaseNode): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = KnowledgeRetrievalNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index a727a826c..eb7b9fc2c 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -41,7 +41,7 @@ class ListOperatorNode(BaseNode): _node_data: ListOperatorNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = ListOperatorNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index 42b8f4e6c..4d1609529 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError): class UnsupportedPromptContentTypeError(LLMNodeError): - def __init__(self, *, type_name: str) -> None: + def __init__(self, *, type_name: str): super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 2441e30c8..fae127ab7 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -107,7 +107,7 @@ def fetch_memory( return memory -def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: +def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage): provider_model_bundle = model_instance.provider_model_bundle provider_configuration = provider_model_bundle.configuration diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 37c4ecfd6..c34a06d98 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -120,7 +120,7 @@ class LLMNode(BaseNode): thread_pool_id: Optional[str] = None, *, llm_file_saver: LLMFileSaver | None = None, - ) -> None: + ): super().__init__( id=id, config=config, @@ -140,7 +140,7 @@ class LLMNode(BaseNode): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = LLMNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -951,7 +951,7 @@ class LLMNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): return { "type": "llm", "config": { diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 53cadc525..892ae88b0 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -18,7 +18,7 @@ class LoopEndNode(BaseNode): _node_data: LoopEndNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopEndNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index ae3927a89..2fe3fb556 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -54,7 +54,7 @@ class LoopNode(BaseNode): _node_data: LoopNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 29b45ea0c..f5a20fc00 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -18,7 +18,7 @@ class LoopStartNode(BaseNode): _node_data: LoopStartNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopStartNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 4e93cd968..273914022 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -98,7 +98,7 @@ class ParameterExtractorNodeData(BaseNodeData): def set_reasoning_mode(cls, v) -> str: return v or "function_call" - def get_parameter_json_schema(self) -> dict: + def get_parameter_json_schema(self): """ Get parameter json schema. diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py index 247518cf2..a1707a246 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError): expected_type: SegmentType, actual_type: SegmentType | None, value: Any, - ) -> None: + ): message = ( f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 43edf7eac..a854c7e87 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -94,7 +94,7 @@ class ParameterExtractorNode(BaseNode): _node_data: ParameterExtractorNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = ParameterExtractorNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -119,7 +119,7 @@ class ParameterExtractorNode(BaseNode): _model_config: Optional[ModelConfigWithCredentialsEntity] = None @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): return { "model": { "prompt_templates": { @@ -545,7 +545,7 @@ class ParameterExtractorNode(BaseNode): return prompt_messages - def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + def _validate_result(self, data: ParameterExtractorNodeData, result: dict): if len(data.parameters) != len(result): raise InvalidNumberOfParametersError("Invalid number of parameters") @@ -597,7 +597,7 @@ class ParameterExtractorNode(BaseNode): except ValueError: return None - def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: + def _transform_result(self, data: ParameterExtractorNodeData, result: dict): """ Transform result into standard format. """ @@ -690,7 +690,7 @@ class ParameterExtractorNode(BaseNode): logger.info("extra error: %s", result) return None - def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: + def _generate_default_result(self, data: ParameterExtractorNodeData): """ Generate default result. """ diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ba4e55bb8..07fb658f2 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -63,7 +63,7 @@ class QuestionClassifierNode(BaseNode): thread_pool_id: Optional[str] = None, *, llm_file_saver: LLMFileSaver | None = None, - ) -> None: + ): super().__init__( id=id, config=config, @@ -83,7 +83,7 @@ class QuestionClassifierNode(BaseNode): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = QuestionClassifierNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -275,7 +275,7 @@ class QuestionClassifierNode(BaseNode): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): """ Get default config of node. :param filters: filter by node config parameters. diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 9e401e76b..6052774e6 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -15,7 +15,7 @@ class StartNode(BaseNode): _node_data: StartNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = StartNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 1962c82db..5588463a3 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -18,7 +18,7 @@ class TemplateTransformNode(BaseNode): _node_data: TemplateTransformNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = TemplateTransformNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -40,7 +40,7 @@ class TemplateTransformNode(BaseNode): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Optional[dict] = None): """ Get default config of node. :param filters: filter by node config parameters. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 1a85c08b5..c4caf5d83 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -45,7 +45,7 @@ class ToolNode(BaseNode): _node_data: ToolNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = ToolNodeData.model_validate(data) @classmethod diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 98127bbeb..cc5092d0a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -15,7 +15,7 @@ class VariableAggregatorNode(BaseNode): _node_data: VariableAssignerNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData(**data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py index 8f7a44bb6..5292a9e44 100644 --- a/api/core/workflow/nodes/variable_assigner/common/impl.py +++ b/api/core/workflow/nodes/variable_assigner/common/impl.py @@ -11,7 +11,7 @@ from .exc import VariableOperatorNodeError class ConversationVariableUpdaterImpl: _engine: Engine | None - def __init__(self, engine: Engine | None = None) -> None: + def __init__(self, engine: Engine | None = None): self._engine = engine def _get_engine(self) -> Engine: diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 321d280b1..263c5a389 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -30,7 +30,7 @@ class VariableAssignerNode(BaseNode): _node_data: VariableAssignerData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: @@ -61,7 +61,7 @@ class VariableAssignerNode(BaseNode): previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, - ) -> None: + ): super().__init__( id=id, config=config, diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py index fd6c304a9..05173b3ca 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py @@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError): class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str) -> None: + def __init__(self, message: str): super().__init__(message) diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 00ee921ce..fdd155adf 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -58,7 +58,7 @@ class VariableAssignerNode(BaseNode): _node_data: VariableAssignerNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: + def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData.model_validate(data) def _get_error_strategy(self) -> Optional[ErrorStrategy]: diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py index bcbd25339..1e2bd79c7 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -16,7 +16,7 @@ class WorkflowExecutionRepository(Protocol): application domains or deployment scenarios. """ - def save(self, execution: WorkflowExecution) -> None: + def save(self, execution: WorkflowExecution): """ Save or update a WorkflowExecution instance. diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index 8bf81f544..f4668c05c 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol): application domains or deployment scenarios. """ - def save(self, execution: WorkflowNodeExecution) -> None: + def save(self, execution: WorkflowNodeExecution): """ Save or update a NodeExecution instance. diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index f86c54c50..a6dd98db5 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -57,7 +57,7 @@ class VariableTemplateParser: self.template = template self.variable_keys = self.extract() - def extract(self) -> list: + def extract(self): """ Extracts all the template variable keys from the template string. diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 3c264e782..4f259b64a 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -48,7 +48,7 @@ class WorkflowCycleManager: workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, - ) -> None: + ): self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables self._workflow_info = workflow_info @@ -299,7 +299,7 @@ class WorkflowCycleManager: error_message: Optional[str] = None, exceptions_count: int = 0, finished_at: Optional[datetime] = None, - ) -> None: + ): """Update workflow execution with completion data.""" execution.status = status execution.outputs = outputs or {} @@ -316,7 +316,7 @@ class WorkflowCycleManager: workflow_execution: WorkflowExecution, conversation_id: Optional[str], external_trace_id: Optional[str], - ) -> None: + ): """Add trace task if trace manager is provided.""" if trace_manager: trace_manager.add_trace_task( @@ -334,7 +334,7 @@ class WorkflowCycleManager: workflow_execution_id: str, error_message: str, now: datetime, - ) -> None: + ): """Fail all running node executions for a workflow.""" running_node_executions = [ node_exec @@ -406,7 +406,7 @@ class WorkflowCycleManager: status: WorkflowNodeExecutionStatus, error: Optional[str] = None, handle_special_values: bool = False, - ) -> None: + ): """Update node execution with completion data.""" finished_at = naive_utc_now() elapsed_time = (finished_at - event.start_at).total_seconds() diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index e9b73df0f..b69a9971b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -48,7 +48,7 @@ class WorkflowEntry: call_depth: int, variable_pool: VariablePool, thread_pool_id: Optional[str] = None, - ) -> None: + ): """ Init workflow entry :param tenant_id: tenant id @@ -320,7 +320,7 @@ class WorkflowEntry: return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod - def _handle_special_values(value: Any) -> Any: + def _handle_special_values(value: Any): if value is None: return value if isinstance(value, dict): @@ -345,7 +345,7 @@ class WorkflowEntry: user_inputs: Mapping[str, Any], variable_pool: VariablePool, tenant_id: str, - ) -> None: + ): # NOTE(QuantumGhost): This logic should remain synchronized with # the implementation of `load_into_variable_pool`, specifically the logic about # variable existence checking. diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 08e12e268..6eac2dd6b 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -13,7 +13,7 @@ class WorkflowRuntimeTypeConverter: result = self._to_json_encodable_recursive(value) return result if isinstance(result, Mapping) or result is None else dict(result) - def _to_json_encodable_recursive(self, value: Any) -> Any: + def _to_json_encodable_recursive(self, value: Any): if value is None: return value if isinstance(value, (bool, int, str, float)): diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 5b6c1511c..21fd0b9c5 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -42,7 +42,7 @@ def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]: @redis_fallback() -def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None: +def _set_last_update_timestamp(cache_key: str, timestamp: datetime): """Set last update timestamp in Redis cache with TTL.""" redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp())) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index b32616b17..db16b6096 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) _GEVENT_COMPATIBILITY_SETUP: bool = False -def _safe_rollback(connection) -> None: +def _safe_rollback(connection): """Safely rollback database connection. Args: @@ -25,7 +25,7 @@ def _safe_rollback(connection) -> None: logger.exception("Failed to rollback connection") -def _setup_gevent_compatibility() -> None: +def _setup_gevent_compatibility(): global _GEVENT_COMPATIBILITY_SETUP # pylint: disable=global-statement # Avoid duplicate registration @@ -33,7 +33,7 @@ def _setup_gevent_compatibility() -> None: return @event.listens_for(Pool, "reset") - def _safe_reset(dbapi_connection, connection_record, reset_state) -> None: # pylint: disable=unused-argument + def _safe_reset(dbapi_connection, connection_record, reset_state): # pylint: disable=unused-argument if reset_state.terminate_only: return diff --git a/api/extensions/ext_orjson.py b/api/extensions/ext_orjson.py index 659784a58..efa1386a6 100644 --- a/api/extensions/ext_orjson.py +++ b/api/extensions/ext_orjson.py @@ -3,6 +3,6 @@ from flask_orjson import OrjsonProvider from dify_app import DifyApp -def init_app(app: DifyApp) -> None: +def init_app(app: DifyApp): """Initialize Flask-Orjson extension for faster JSON serialization""" app.json = OrjsonProvider(app) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index adf2944c9..33fa7d0a8 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -40,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict): """Validate the configuration values. This method will first try to use CLICKZETTA_VOLUME_* environment variables, @@ -217,7 +217,7 @@ class ClickZettaVolumeStorage(BaseStorage): logger.exception("SQL execution failed: %s", sql) raise - def _ensure_table_volume_exists(self, dataset_id: str) -> None: + def _ensure_table_volume_exists(self, dataset_id: str): """Ensure table volume exists for the given dataset_id.""" if self._config.volume_type != "table" or not dataset_id: return @@ -252,7 +252,7 @@ class ClickZettaVolumeStorage(BaseStorage): # Don't raise exception, let the operation continue # The table might exist but not be visible due to permissions - def save(self, filename: str, data: bytes) -> None: + def save(self, filename: str, data: bytes): """Save data to ClickZetta Volume. Args: diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index c41344774..ef6b12fd5 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -38,7 +38,7 @@ class FileMetadata: tags: Optional[dict[str, str]] = None parent_version: Optional[int] = None - def to_dict(self) -> dict: + def to_dict(self): """Convert to dictionary format""" data = asdict(self) data["created_at"] = self.created_at.isoformat() diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index d216790f1..243df92ef 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -623,7 +623,7 @@ class VolumePermissionError(Exception): def check_volume_permission( permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None -) -> None: +): """Permission check decorator function Args: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 0ba35506d..21b82d79e 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -40,7 +40,7 @@ class OpenDALStorage(BaseStorage): self.op = self.op.layer(retry_layer) logger.debug("added retry layer to opendal operator") - def save(self, filename: str, data: bytes) -> None: + def save(self, filename: str, data: bytes): self.op.write(path=filename, bs=data) logger.debug("file %s saved", filename) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 62e3bfa3b..9433b312c 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -403,7 +403,7 @@ class StorageKeyLoader: This loader is batched, the database query count is constant regardless of the input size. """ - def __init__(self, session: Session, tenant_id: str) -> None: + def __init__(self, session: Session, tenant_id: str): self._session = session self._tenant_id = tenant_id diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index b7c9f3ec6..3c039dff5 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -128,7 +128,7 @@ class FeatureBrandingService: class EmailSender(Protocol): """Protocol for email sending abstraction.""" - def send_email(self, to: str, subject: str, html_content: str) -> None: + def send_email(self, to: str, subject: str, html_content: str): """Send email with given parameters.""" ... @@ -136,7 +136,7 @@ class EmailSender(Protocol): class FlaskMailSender: """Flask-Mail based email sender.""" - def send_email(self, to: str, subject: str, html_content: str) -> None: + def send_email(self, to: str, subject: str, html_content: str): """Send email using Flask-Mail.""" if mail.is_inited(): mail.send(to=to, subject=subject, html=html_content) @@ -156,7 +156,7 @@ class EmailI18nService: renderer: EmailRenderer, branding_service: BrandingService, sender: EmailSender, - ) -> None: + ): self._config = config self._renderer = renderer self._branding_service = branding_service @@ -168,7 +168,7 @@ class EmailI18nService: language_code: str, to: str, template_context: Optional[dict[str, Any]] = None, - ) -> None: + ): """ Send internationalized email with branding support. @@ -192,7 +192,7 @@ class EmailI18nService: to: str, code: str, phase: str, - ) -> None: + ): """ Send change email notification with phase-specific handling. @@ -224,7 +224,7 @@ class EmailI18nService: to: str | list[str], subject: str, html_content: str, - ) -> None: + ): """ Send a raw email directly without template processing. diff --git a/api/libs/external_api.py b/api/libs/external_api.py index d5409c4b4..cee80f7f2 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -16,7 +16,7 @@ def http_status_message(code): return HTTP_STATUS_CODES.get(code, "") -def register_external_error_handlers(api: Api) -> None: +def register_external_error_handlers(api: Api): @api.errorhandler(HTTPException) def handle_http_exception(e: HTTPException): got_request_exception.send(current_app, exception=e) diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 9ab53b629..0c642041b 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -3,7 +3,7 @@ import json from core.llm_generator.output_parser.errors import OutputParserError -def parse_json_markdown(json_string: str) -> dict: +def parse_json_markdown(json_string: str): # Get json from the backticks/braces json_string = json_string.strip() starts = ["```json", "```", "``", "`", "{"] @@ -33,7 +33,7 @@ def parse_json_markdown(json_string: str) -> dict: return parsed -def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: +def parse_and_check_json_markdown(text: str, expected_keys: list[str]): try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py index 616d072a1..9f7494343 100644 --- a/api/libs/module_loading.py +++ b/api/libs/module_loading.py @@ -7,10 +7,9 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py import sys from importlib import import_module -from typing import Any -def cached_import(module_path: str, class_name: str) -> Any: +def cached_import(module_path: str, class_name: str): """ Import a module and return the named attribute/class from it, with caching. @@ -30,7 +29,7 @@ def cached_import(module_path: str, class_name: str) -> Any: return getattr(module, class_name) -def import_string(dotted_path: str) -> Any: +def import_string(dotted_path: str): """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. diff --git a/api/models/account.py b/api/models/account.py index 6db1381df..4fec41c4e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -225,7 +225,7 @@ class Tenant(Base): ) @property - def custom_config_dict(self) -> dict: + def custom_config_dict(self): return json.loads(self.custom_config) if self.custom_config else {} @custom_config_dict.setter diff --git a/api/models/model.py b/api/models/model.py index aa1a87e3b..fbebdc817 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -162,7 +162,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self) -> list: + def deleted_tools(self): from core.tools.tool_manager import ToolManager from services.plugin.plugin_service import PluginService @@ -339,15 +339,15 @@ class AppModelConfig(Base): return app @property - def model_dict(self) -> dict: + def model_dict(self): return json.loads(self.model) if self.model else {} @property - def suggested_questions_list(self) -> list: + def suggested_questions_list(self): return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self) -> dict: + def suggested_questions_after_answer_dict(self): return ( json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer @@ -355,19 +355,19 @@ class AppModelConfig(Base): ) @property - def speech_to_text_dict(self) -> dict: + def speech_to_text_dict(self): return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property - def text_to_speech_dict(self) -> dict: + def text_to_speech_dict(self): return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property - def retriever_resource_dict(self) -> dict: + def retriever_resource_dict(self): return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property - def annotation_reply_dict(self) -> dict: + def annotation_reply_dict(self): annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) @@ -390,11 +390,11 @@ class AppModelConfig(Base): return {"enabled": False} @property - def more_like_this_dict(self) -> dict: + def more_like_this_dict(self): return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} @property - def sensitive_word_avoidance_dict(self) -> dict: + def sensitive_word_avoidance_dict(self): return ( json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance @@ -410,7 +410,7 @@ class AppModelConfig(Base): return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self) -> dict: + def agent_mode_dict(self): return ( json.loads(self.agent_mode) if self.agent_mode @@ -418,15 +418,15 @@ class AppModelConfig(Base): ) @property - def chat_prompt_config_dict(self) -> dict: + def chat_prompt_config_dict(self): return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} @property - def completion_prompt_config_dict(self) -> dict: + def completion_prompt_config_dict(self): return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} @property - def dataset_configs_dict(self) -> dict: + def dataset_configs_dict(self): if self.dataset_configs: dataset_configs: dict = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: @@ -438,7 +438,7 @@ class AppModelConfig(Base): } @property - def file_upload_dict(self) -> dict: + def file_upload_dict(self): return ( json.loads(self.file_upload) if self.file_upload @@ -452,7 +452,7 @@ class AppModelConfig(Base): } ) - def to_dict(self) -> dict: + def to_dict(self): return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -1087,7 +1087,7 @@ class Message(Base): return self.override_model_configs is not None @property - def message_metadata_dict(self) -> dict: + def message_metadata_dict(self): return json.loads(self.message_metadata) if self.message_metadata else {} @property @@ -1176,7 +1176,7 @@ class Message(Base): return None - def to_dict(self) -> dict: + def to_dict(self): return { "id": self.id, "app_id": self.app_id, @@ -1689,7 +1689,7 @@ class MessageAgentThought(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property - def files(self) -> list: + def files(self): if self.message_files: return cast(list[Any], json.loads(self.message_files)) else: @@ -1700,7 +1700,7 @@ class MessageAgentThought(Base): return self.tool.split(";") if self.tool else [] @property - def tool_labels(self) -> dict: + def tool_labels(self): try: if self.tool_labels_str: return cast(dict, json.loads(self.tool_labels_str)) @@ -1710,7 +1710,7 @@ class MessageAgentThought(Base): return {} @property - def tool_meta(self) -> dict: + def tool_meta(self): try: if self.tool_meta_str: return cast(dict, json.loads(self.tool_meta_str)) @@ -1720,7 +1720,7 @@ class MessageAgentThought(Base): return {} @property - def tool_inputs_dict(self) -> dict: + def tool_inputs_dict(self): tools = self.tools try: if self.tool_input: diff --git a/api/models/tools.py b/api/models/tools.py index 9c460e9bf..8755570ee 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Any, Optional, cast +from typing import Optional, cast from urllib.parse import urlparse import sqlalchemy as sa @@ -54,7 +54,7 @@ class ToolOAuthTenantClient(Base): encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) @property - def oauth_params(self) -> dict: + def oauth_params(self): return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) @@ -96,7 +96,7 @@ class BuiltinToolProvider(Base): expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) @property - def credentials(self) -> dict: + def credentials(self): return cast(dict, json.loads(self.encrypted_credentials)) @@ -146,7 +146,7 @@ class ApiToolProvider(Base): return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] @property - def credentials(self) -> dict: + def credentials(self): return dict(json.loads(self.credentials_str)) @property @@ -289,7 +289,7 @@ class MCPToolProvider(Base): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def credentials(self) -> dict: + def credentials(self): try: return cast(dict, json.loads(self.encrypted_credentials)) or {} except Exception: @@ -327,7 +327,7 @@ class MCPToolProvider(Base): return mask_url(self.decrypted_server_url) @property - def decrypted_credentials(self) -> dict: + def decrypted_credentials(self): from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.utils.encryption import create_provider_encrypter @@ -408,7 +408,7 @@ class ToolConversationVariables(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> Any: + def variables(self): return json.loads(self.variables_str) diff --git a/api/models/workflow.py b/api/models/workflow.py index 28bf683fb..23f18929d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -282,14 +282,14 @@ class Workflow(Base): return self._features @features.setter - def features(self, value: str) -> None: + def features(self, value: str): self._features = value @property def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} - def user_input_form(self, to_old_structure: bool = False) -> list: + def user_input_form(self, to_old_structure: bool = False): # get start node from graph if not self.graph: return [] @@ -439,7 +439,7 @@ class Workflow(Base): return results @conversation_variables.setter - def conversation_variables(self, value: Sequence[Variable]) -> None: + def conversation_variables(self, value: Sequence[Variable]): self._conversation_variables = json.dumps( {var.name: var.model_dump() for var in value}, ensure_ascii=False, @@ -892,7 +892,7 @@ class ConversationVariable(Base): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: + def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str): self.id = id self.app_id = app_id self.conversation_id = conversation_id @@ -1073,7 +1073,7 @@ class WorkflowDraftVariable(Base): return self.build_segment_with_type(self.value_type, value) @staticmethod - def rebuild_file_types(value: Any) -> Any: + def rebuild_file_types(value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. # By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index e69ea9b7c..6294846f5 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -46,7 +46,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): session_maker: SQLAlchemy sessionmaker instance for database connections """ - def __init__(self, session_maker: sessionmaker[Session]) -> None: + def __init__(self, session_maker: sessionmaker[Session]): """ Initialize the repository with a sessionmaker. diff --git a/api/services/account_service.py b/api/services/account_service.py index 660c80ebf..a76792f88 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -105,14 +105,14 @@ class AccountService: return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" @staticmethod - def _store_refresh_token(refresh_token: str, account_id: str) -> None: + def _store_refresh_token(refresh_token: str, account_id: str): redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) redis_client.setex( AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token ) @staticmethod - def _delete_refresh_token(refresh_token: str, account_id: str) -> None: + def _delete_refresh_token(refresh_token: str, account_id: str): redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) @@ -312,12 +312,12 @@ class AccountService: return True @staticmethod - def delete_account(account: Account) -> None: + def delete_account(account: Account): """Delete account. This method only adds a task to the queue for deletion.""" delete_account_task.delay(account.id) @staticmethod - def link_account_integrate(provider: str, open_id: str, account: Account) -> None: + def link_account_integrate(provider: str, open_id: str, account: Account): """Link account integrate""" try: # Query whether there is an existing binding record for the same provider @@ -344,7 +344,7 @@ class AccountService: raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod - def close_account(account: Account) -> None: + def close_account(account: Account): """Close account""" account.status = AccountStatus.CLOSED.value db.session.commit() @@ -374,7 +374,7 @@ class AccountService: return account @staticmethod - def update_login_info(account: Account, *, ip_address: str) -> None: + def update_login_info(account: Account, *, ip_address: str): """Update last login time and ip""" account.last_login_at = naive_utc_now() account.last_login_ip = ip_address @@ -398,7 +398,7 @@ class AccountService: return TokenPair(access_token=access_token, refresh_token=refresh_token) @staticmethod - def logout(*, account: Account) -> None: + def logout(*, account: Account): refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) if refresh_token: AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @@ -705,7 +705,7 @@ class AccountService: @staticmethod @redis_fallback(default_return=None) - def add_login_error_rate_limit(email: str) -> None: + def add_login_error_rate_limit(email: str): key = f"login_error_rate_limit:{email}" count = redis_client.get(key) if count is None: @@ -734,7 +734,7 @@ class AccountService: @staticmethod @redis_fallback(default_return=None) - def add_forgot_password_error_rate_limit(email: str) -> None: + def add_forgot_password_error_rate_limit(email: str): key = f"forgot_password_error_rate_limit:{email}" count = redis_client.get(key) if count is None: @@ -763,7 +763,7 @@ class AccountService: @staticmethod @redis_fallback(default_return=None) - def add_change_email_error_rate_limit(email: str) -> None: + def add_change_email_error_rate_limit(email: str): key = f"change_email_error_rate_limit:{email}" count = redis_client.get(key) if count is None: @@ -791,7 +791,7 @@ class AccountService: @staticmethod @redis_fallback(default_return=None) - def add_owner_transfer_error_rate_limit(email: str) -> None: + def add_owner_transfer_error_rate_limit(email: str): key = f"owner_transfer_error_rate_limit:{email}" count = redis_client.get(key) if count is None: @@ -970,7 +970,7 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None: + def switch_tenant(account: Account, tenant_id: Optional[str] = None): """Switch the current workspace for the account""" # Ensure tenant_id is provided @@ -1067,7 +1067,7 @@ class TenantService: return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod - def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: + def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): """Check member permission""" perms = { "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], @@ -1087,7 +1087,7 @@ class TenantService: raise NoPermissionError(f"No permission to {action} member.") @staticmethod - def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: + def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account): """Remove member from tenant""" if operator.id == account.id: raise CannotOperateSelfError("Cannot operate self.") @@ -1102,7 +1102,7 @@ class TenantService: db.session.commit() @staticmethod - def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: + def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") @@ -1129,7 +1129,7 @@ class TenantService: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> dict: + def get_custom_config(tenant_id: str): tenant = db.get_or_404(Tenant, tenant_id) return tenant.custom_config_dict @@ -1150,7 +1150,7 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: + def setup(cls, email: str, name: str, password: str, ip_address: str): """ Setup dify diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 6dc1affa1..6f0ab2546 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -17,7 +17,7 @@ from models.model import AppMode class AdvancedPromptTemplateService: @classmethod - def get_prompt(cls, args: dict) -> dict: + def get_prompt(cls, args: dict): app_mode = args["app_mode"] model_mode = args["model_mode"] model_name = args["model_name"] @@ -29,7 +29,7 @@ class AdvancedPromptTemplateService: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: @@ -52,7 +52,7 @@ class AdvancedPromptTemplateService: return {} @classmethod - def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: + def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str): if has_context == "true": prompt_template["completion_prompt_config"]["prompt"]["text"] = ( context + prompt_template["completion_prompt_config"]["prompt"]["text"] @@ -61,7 +61,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: + def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str): if has_context == "true": prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] @@ -70,7 +70,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 7c6df2428..72833b9d6 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -16,7 +16,7 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough class AgentService: @classmethod - def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str): """ Service to get agent logs """ diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 45656e790..24567cc34 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -73,7 +73,7 @@ class AppAnnotationService: return annotation @classmethod - def enable_app_annotation(cls, args: dict, app_id: str) -> dict: + def enable_app_annotation(cls, args: dict, app_id: str): enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: @@ -96,7 +96,7 @@ class AppAnnotationService: return {"job_id": job_id, "job_status": "waiting"} @classmethod - def disable_app_annotation(cls, app_id: str) -> dict: + def disable_app_annotation(cls, app_id: str): disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: @@ -315,7 +315,7 @@ class AppAnnotationService: return {"deleted_count": deleted_count} @classmethod - def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: + def batch_import_app_annotations(cls, app_id, file: FileStorage): # get app info app = ( db.session.query(App) @@ -490,7 +490,7 @@ class AppAnnotationService: } @classmethod - def clear_all_annotations(cls, app_id: str) -> dict: + def clear_all_annotations(cls, app_id: str): app = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 2f28eff16..3a0ed41be 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -30,7 +30,7 @@ class APIBasedExtensionService: return extension_data @staticmethod - def delete(extension_data: APIBasedExtension) -> None: + def delete(extension_data: APIBasedExtension): db.session.delete(extension_data) db.session.commit() @@ -51,7 +51,7 @@ class APIBasedExtensionService: return extension @classmethod - def _validation(cls, extension_data: APIBasedExtension) -> None: + def _validation(cls, extension_data: APIBasedExtension): # name if not extension_data.name: raise ValueError("name must not be empty") @@ -95,7 +95,7 @@ class APIBasedExtensionService: cls._ping_connection(extension_data) @staticmethod - def _ping_connection(extension_data: APIBasedExtension) -> None: + def _ping_connection(extension_data: APIBasedExtension): try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2663cb380..2344be0aa 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -566,7 +566,7 @@ class AppDslService: @classmethod def _append_workflow_export_data( cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None - ) -> None: + ): """ Append workflow export data :param export_data: export data @@ -608,7 +608,7 @@ class AppDslService: ] @classmethod - def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: + def _append_model_config_export_data(cls, export_data: dict, app_model: App): """ Append model config export data :param export_data: export data diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index a1ad27105..6f54f9073 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -6,7 +6,7 @@ from models.model import AppMode class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode): if app_mode == AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index 1df926cc2..4502fa929 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -316,7 +316,7 @@ class AppService: return app - def delete_app(self, app: App) -> None: + def delete_app(self, app: App): """ Delete app :param app: App instance @@ -331,7 +331,7 @@ class AppService: # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) - def get_app_meta(self, app_model: App) -> dict: + def get_app_meta(self, app_model: App): """ Get app meta info :param app_model: app model diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 996e9187f..f6e960b41 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -8,7 +8,7 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod - def get_provider_auth_list(tenant_id: str) -> list: + def get_provider_auth_list(tenant_id: str): data_source_api_key_bindings = ( db.session.query(DataSourceApiKeyAuthBinding) .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index de00e7463..2f1b63664 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) class ClearFreePlanTenantExpiredLogs: @classmethod - def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None: + def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]): """ Clean up message-related tables to avoid data redundancy. This method cleans up tables that have foreign key relationships with Message. @@ -353,7 +353,7 @@ class ClearFreePlanTenantExpiredLogs: thread_pool = ThreadPoolExecutor(max_workers=10) - def process_tenant(flask_app: Flask, tenant_id: str) -> None: + def process_tenant(flask_app: Flask, tenant_id: str): try: if ( not dify_config.BILLING_ENABLED diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index f7597b7f1..7c893463d 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension class CodeBasedExtensionService: @staticmethod - def get_code_based_extension(module: str) -> list[dict]: + def get_code_based_extension(module: str): module_extensions = code_based_extension.module_extensions(module) return [ { diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ac603d3cc..d017ce54a 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -250,7 +250,7 @@ class ConversationService: variable_id: str, user: Optional[Union[Account, EndUser]], new_value: Any, - ) -> dict: + ): """ Update a conversation variable's value. diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4b202001d..e0885f325 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -719,7 +719,7 @@ class DatasetService: ) @staticmethod - def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + def get_dataset_auto_disable_logs(dataset_id: str): features = FeatureService.get_features(current_user.current_tenant_id) if not features.billing.enabled or features.billing.subscription.plan == "sandbox": return { diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 1fe259dd4..647052d73 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -83,7 +83,7 @@ class ProviderResponse(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data) -> None: + def __init__(self, **data): super().__init__(**data) url_prefix = ( @@ -113,7 +113,7 @@ class ProviderWithModelsResponse(BaseModel): status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] - def __init__(self, **data) -> None: + def __init__(self, **data): super().__init__(**data) url_prefix = ( @@ -137,7 +137,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): tenant_id: str - def __init__(self, **data) -> None: + def __init__(self, **data): super().__init__(**data) url_prefix = ( @@ -174,7 +174,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity): provider: SimpleProviderEntityResponse - def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: + def __init__(self, tenant_id: str, model: ModelWithProviderEntity): dump_model = model.model_dump() dump_model["provider"]["tenant_id"] = tenant_id super().__init__(**dump_model) diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py index e4fac6f74..ca4c9a611 100644 --- a/api/services/errors/llm.py +++ b/api/services/errors/llm.py @@ -6,7 +6,7 @@ class InvokeError(Exception): description: Optional[str] = None - def __init__(self, description: Optional[str] = None) -> None: + def __init__(self, description: Optional[str] = None): self.description = description def __str__(self): diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 077571ffb..783d6c242 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -277,7 +277,7 @@ class ExternalDatasetService: query: str, external_retrieval_parameters: dict, metadata_condition: Optional[MetadataCondition] = None, - ) -> list: + ): external_knowledge_binding = ( db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() ) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index bce28da03..00ec3babf 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -33,7 +33,7 @@ class HitTestingService: retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, - ) -> dict: + ): start = time.perf_counter() # get retrieval model , if the model is not setting , using default @@ -98,7 +98,7 @@ class HitTestingService: account: Account, external_retrieval_model: dict, metadata_filtering_conditions: dict, - ) -> dict: + ): if dataset.provider != "external": return { "query": {"content": query}, diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 17696f5cd..c638087f6 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -25,10 +25,10 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self) -> None: + def __init__(self): self.provider_manager = ProviderManager() - def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ enable model load balancing. @@ -49,7 +49,7 @@ class ModelLoadBalancingService: # Enable model load balancing provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ disable model load balancing. @@ -295,7 +295,7 @@ class ModelLoadBalancingService: def update_load_balancing_configs( self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str - ) -> None: + ): """ Update load balancing configurations. :param tenant_id: workspace id @@ -478,7 +478,7 @@ class ModelLoadBalancingService: model_type: str, credentials: dict, config_id: Optional[str] = None, - ) -> None: + ): """ Validate load balancing credentials. :param tenant_id: workspace id @@ -537,7 +537,7 @@ class ModelLoadBalancingService: credentials: dict, load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, validate: bool = True, - ) -> dict: + ): """ Validate custom credentials. :param tenant_id: workspace id @@ -605,7 +605,7 @@ class ModelLoadBalancingService: else: raise ValueError("No credential schema found") - def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: + def _clear_credentials_cache(self, tenant_id: str, config_id: str): """ Clear credentials cache. :param tenant_id: workspace id diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 69c7e4cf5..510b1f1fe 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -26,7 +26,7 @@ class ModelProviderService: Model Provider Service """ - def __init__(self) -> None: + def __init__(self): self.provider_manager = ProviderManager() def _get_provider_configuration(self, tenant_id: str, provider: str): @@ -142,7 +142,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore - def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): """ validate provider credentials before saving. @@ -193,7 +193,7 @@ class ModelProviderService: credential_name=credential_name, ) - def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str): """ remove a saved provider credential (by credential_id). :param tenant_id: workspace id @@ -204,7 +204,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.delete_provider_credential(credential_id=credential_id) - def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str): """ :param tenant_id: workspace id :param provider: provider name @@ -232,9 +232,7 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def validate_model_credentials( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict - ) -> None: + def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict): """ validate model credentials. @@ -303,9 +301,7 @@ class ModelProviderService: credential_name=credential_name, ) - def remove_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str - ) -> None: + def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str): """ remove model credentials. @@ -323,7 +319,7 @@ class ModelProviderService: def switch_active_custom_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str - ) -> None: + ): """ switch model credentials. @@ -341,7 +337,7 @@ class ModelProviderService: def add_model_credential_to_model_list( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str - ) -> None: + ): """ add model credentials to model list. @@ -357,7 +353,7 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: + def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str): """ remove model credentials. @@ -485,7 +481,7 @@ class ModelProviderService: logger.debug("get_default_model_of_model_type error: %s", e) return None - def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: + def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str): """ update default model of model type. @@ -517,7 +513,7 @@ class ModelProviderService: return byte_data, mime_type - def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: + def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str): """ switch preferred provider. @@ -534,7 +530,7 @@ class ModelProviderService: # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) - def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str): """ enable model. @@ -547,7 +543,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) - def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str): """ disable model. diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 39585d783..71a7b34a7 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) class PluginDataMigration: @classmethod - def migrate(cls) -> None: + def migrate(cls): cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table cls.migrate_db_records("provider_models", "provider_name", ModelProviderID) cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID) @@ -26,7 +26,7 @@ class PluginDataMigration: cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID) @classmethod - def migrate_datasets(cls) -> None: + def migrate_datasets(cls): table_name = "datasets" provider_column_name = "embedding_model_provider" @@ -126,9 +126,7 @@ limit 1000""" ) @classmethod - def migrate_db_records( - cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] - ) -> None: + def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]): click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) processed_count = 0 diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 221069b2b..8dbf117fd 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -33,7 +33,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"] class PluginMigration: @classmethod - def extract_plugins(cls, filepath: str, workers: int) -> None: + def extract_plugins(cls, filepath: str, workers: int): """ Migrate plugin. """ @@ -55,7 +55,7 @@ class PluginMigration: thread_pool = ThreadPoolExecutor(max_workers=workers) - def process_tenant(flask_app: Flask, tenant_id: str) -> None: + def process_tenant(flask_app: Flask, tenant_id: str): with flask_app.app_context(): nonlocal handled_tenant_count try: @@ -291,7 +291,7 @@ class PluginMigration: return plugin_manifest[0].latest_package_identifier @classmethod - def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None: + def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str): """ Extract unique plugins. """ @@ -328,7 +328,7 @@ class PluginMigration: return {"plugins": plugins, "plugin_not_exist": plugin_not_exist} @classmethod - def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None: + def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100): """ Install plugins. """ @@ -348,7 +348,7 @@ class PluginMigration: if response.get("failed"): plugin_install_failed.extend(response.get("failed", [])) - def install(tenant_id: str, plugin_ids: list[str]) -> None: + def install(tenant_id: str, plugin_ids: list[str]): logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id) # fetch plugin already installed installed_plugins = manager.list_plugins(tenant_id) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 523aebeed..df9e01e27 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -19,7 +19,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): def get_type(self) -> str: return RecommendAppType.BUILDIN - def get_recommended_apps_and_categories(self, language: str) -> dict: + def get_recommended_apps_and_categories(self, language: str): result = self.fetch_recommended_apps_from_builtin(language) return result @@ -28,7 +28,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): return result @classmethod - def _get_builtin_data(cls) -> dict: + def _get_builtin_data(cls): """ Get builtin data. :return: @@ -44,7 +44,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): return cls.builtin_data or {} @classmethod - def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: + def fetch_recommended_apps_from_builtin(cls, language: str): """ Fetch recommended apps from builtin. :param language: language diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index b97d13d01..e19f53f12 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): Retrieval recommended app from database """ - def get_recommended_apps_and_categories(self, language: str) -> dict: + def get_recommended_apps_and_categories(self, language: str): result = self.fetch_recommended_apps_from_db(language) return result @@ -25,7 +25,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str) -> dict: + def fetch_recommended_apps_from_db(cls, language: str): """ Fetch recommended apps from db. :param language: language diff --git a/api/services/recommend_app/recommend_app_base.py b/api/services/recommend_app/recommend_app_base.py index 00c037710..1f62fbf9d 100644 --- a/api/services/recommend_app/recommend_app_base.py +++ b/api/services/recommend_app/recommend_app_base.py @@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC): """Interface for recommend app retrieval.""" @abstractmethod - def get_recommended_apps_and_categories(self, language: str) -> dict: + def get_recommended_apps_and_categories(self, language: str): raise NotImplementedError @abstractmethod diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 85f3a0282..1e5928742 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -24,7 +24,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id) return result - def get_recommended_apps_and_categories(self, language: str) -> dict: + def get_recommended_apps_and_categories(self, language: str): try: result = self.fetch_recommended_apps_from_dify_official(language) except Exception as e: @@ -51,7 +51,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): return data @classmethod - def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: + def fetch_recommended_apps_from_dify_official(cls, language: str): """ Fetch recommended apps from dify official. :param language: language diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 2aebe6b6b..d9c1b51fa 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -6,7 +6,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa class RecommendedAppService: @classmethod - def get_recommended_apps_and_categories(cls, language: str) -> dict: + def get_recommended_apps_and_categories(cls, language: str): """ Get recommended apps and categories. :param language: language diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 2e5e96214..a16bdb46c 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -12,7 +12,7 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod - def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list: + def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None): query = ( db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) @@ -25,7 +25,7 @@ class TagService: return results @staticmethod - def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: + def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list): # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] @@ -51,7 +51,7 @@ class TagService: return results @staticmethod - def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list: + def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): if not tag_type or not tag_name: return [] tags = ( @@ -64,7 +64,7 @@ class TagService: return tags @staticmethod - def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: + def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str): tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 75da5e5ea..2f8a91ed8 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -37,7 +37,7 @@ class WorkflowToolManageService: parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: list[str] | None = None, - ) -> dict: + ): WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique @@ -103,7 +103,7 @@ class WorkflowToolManageService: parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: list[str] | None = None, - ) -> dict: + ): """ Update a workflow tool. :param user_id: the user id @@ -217,7 +217,7 @@ class WorkflowToolManageService: return result @classmethod - def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Delete a workflow tool. :param user_id: the user id @@ -233,7 +233,7 @@ class WorkflowToolManageService: return {"result": "success"} @classmethod - def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: + def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Get a workflow tool. :param user_id: the user id @@ -249,7 +249,7 @@ class WorkflowToolManageService: return cls._get_workflow_tool(tenant_id, db_tool) @classmethod - def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: + def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str): """ Get a workflow tool. :param user_id: the user id @@ -265,7 +265,7 @@ class WorkflowToolManageService: return cls._get_workflow_tool(tenant_id, db_tool) @classmethod - def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict: + def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. :db_tool: the database tool diff --git a/api/services/website_service.py b/api/services/website_service.py index 991b66973..131b96db1 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -132,7 +132,7 @@ class WebsiteService: return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) @classmethod - def document_create_args_validate(cls, args: dict) -> None: + def document_create_args_validate(cls, args: dict): """Validate arguments for document creation.""" try: WebsiteCrawlApiRequest.from_args(args) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 00b02f809..2994856b5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -217,7 +217,7 @@ class WorkflowConverter: return app_config - def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: + def _convert_to_start_node(self, variables: list[VariableEntity]): """ Convert to Start Node :param variables: list of variables @@ -384,7 +384,7 @@ class WorkflowConverter: prompt_template: PromptTemplateEntity, file_upload: Optional[FileUploadConfig] = None, external_data_variable_node_mapping: dict[str, str] | None = None, - ) -> dict: + ): """ Convert to LLM Node :param original_app_mode: original app mode @@ -550,7 +550,7 @@ class WorkflowConverter: return template - def _convert_to_end_node(self) -> dict: + def _convert_to_end_node(self): """ Convert to End Node :return: @@ -566,7 +566,7 @@ class WorkflowConverter: }, } - def _convert_to_answer_node(self) -> dict: + def _convert_to_answer_node(self): """ Convert to Answer Node :return: @@ -578,7 +578,7 @@ class WorkflowConverter: "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, } - def _create_edge(self, source: str, target: str) -> dict: + def _create_edge(self, source: str, target: str): """ Create Edge :param source: source node id @@ -587,7 +587,7 @@ class WorkflowConverter: """ return {"id": f"{source}-{target}", "source": source, "target": target} - def _append_node(self, graph: dict, node: dict) -> dict: + def _append_node(self, graph: dict, node: dict): """ Append Node to Graph diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 6eabf0301..eda55d31d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -23,7 +23,7 @@ class WorkflowAppService: limit: int = 20, created_by_end_user_session_id: str | None = None, created_by_account: str | None = None, - ) -> dict: + ): """ Get paginate workflow app logs using SQLAlchemy 2.0 style :param session: SQLAlchemy session diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index b3b581093..ae5f0a998 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -67,7 +67,7 @@ class DraftVarLoader(VariableLoader): app_id: str, tenant_id: str, fallback_variables: Sequence[Variable] | None = None, - ) -> None: + ): self._engine = engine self._app_id = app_id self._tenant_id = tenant_id @@ -117,7 +117,7 @@ class DraftVarLoader(VariableLoader): class WorkflowDraftVariableService: _session: Session - def __init__(self, session: Session) -> None: + def __init__(self, session: Session): """ Initialize the WorkflowDraftVariableService with a SQLAlchemy session. @@ -438,7 +438,7 @@ def _batch_upsert_draft_variable( session: Session, draft_vars: Sequence[WorkflowDraftVariable], policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, -) -> None: +): if not draft_vars: return None # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3f54f6624..350e52e43 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -591,7 +591,7 @@ class WorkflowService: return new_app - def validate_features_structure(self, app_model: App, features: dict) -> dict: + def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py index dc2751a65..756b67c93 100644 --- a/api/tasks/delete_conversation_task.py +++ b/api/tasks/delete_conversation_task.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="conversation") -def delete_conversation_related_data(conversation_id: str) -> None: +def delete_conversation_related_data(conversation_id: str): """ Delete related data conversation in correct order from datatbase to respect foreign key constraints diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py index 41e8bc932..ae42dff90 100644 --- a/api/tasks/mail_account_deletion_task.py +++ b/api/tasks/mail_account_deletion_task.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_deletion_success_task(to: str, language: str = "en-US") -> None: +def send_deletion_success_task(to: str, language: str = "en-US"): """ Send account deletion success email with internationalization support. @@ -46,7 +46,7 @@ def send_deletion_success_task(to: str, language: str = "en-US") -> None: @shared_task(queue="mail") -def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None: +def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US"): """ Send account deletion verification code email with internationalization support. diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py index c090a8492..a974e807b 100644 --- a/api/tasks/mail_change_mail_task.py +++ b/api/tasks/mail_change_mail_task.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None: +def send_change_mail_task(language: str, to: str, code: str, phase: str): """ Send change email notification with internationalization support. @@ -43,7 +43,7 @@ def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None @shared_task(queue="mail") -def send_change_mail_completed_notification_task(language: str, to: str) -> None: +def send_change_mail_completed_notification_task(language: str, to: str): """ Send change email completed notification with internationalization support. diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index 126c169d0..e97eae92d 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_email_code_login_mail_task(language: str, to: str, code: str) -> None: +def send_email_code_login_mail_task(language: str, to: str, code: str): """ Send email code login email with internationalization support. diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a5d59d745..8b091fe0b 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None: +def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): """ Send invite member email with internationalization support. diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py index 33a8e1743..6a72dde2f 100644 --- a/api/tasks/mail_owner_transfer_task.py +++ b/api/tasks/mail_owner_transfer_task.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None: +def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str): """ Send owner transfer confirmation email with internationalization support. @@ -52,7 +52,7 @@ def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspac @shared_task(queue="mail") -def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None: +def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str): """ Send old owner transfer notification email with internationalization support. @@ -93,7 +93,7 @@ def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: @shared_task(queue="mail") -def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None: +def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str): """ Send new owner transfer notification email with internationalization support. diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 1fcc2bfba..545db84fd 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="mail") -def send_reset_password_mail_task(language: str, to: str, code: str) -> None: +def send_reset_password_mail_task(language: str, to: str, code: str): """ Send reset password email with internationalization support. diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 7bfda3d74..241e04e4d 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -395,7 +395,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: return total_deleted -def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str): while True: with db.engine.begin() as conn: rs = conn.execute(sa.text(query_sql), params) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 77ddf8302..7d145fb50 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -120,7 +120,7 @@ def _create_workflow_run_from_execution( return workflow_run -def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None: +def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution): """ Update a WorkflowRun database model from a WorkflowExecution domain entity. """ diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 16356086c..8f5127670 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -140,9 +140,7 @@ def _create_node_execution_from_domain( return node_execution -def _update_node_execution_from_domain( - node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution -) -> None: +def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution): """ Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. """ diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index d9f90f992..597e7330b 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -14,7 +14,7 @@ from services.account_service import AccountService, RegisterService # Loading the .env file if it exists -def _load_env() -> None: +def _load_env(): current_file_path = pathlib.Path(__file__).absolute() # Items later in the list have higher precedence. files_to_load = [".env", "vdb.env"] diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py b/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py index c8cb7528e..d4cd5df55 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py @@ -17,7 +17,7 @@ def mock_plugin_daemon( :return: unpatch function """ - def unpatch() -> None: + def unpatch(): monkeypatch.undo() monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm) diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 02f658aad..fd7ab0a22 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -150,7 +150,7 @@ class MockTcvectordbClass: filter: Optional[Filter] = None, output_fields: Optional[list[str]] = None, timeout: Optional[float] = None, - ) -> list[dict]: + ): return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] def collection_delete( @@ -163,7 +163,7 @@ class MockTcvectordbClass: ): return {"code": 0, "msg": "operation success"} - def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict: + def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index 50519e205..a033443cf 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -26,7 +26,7 @@ def get_example_document(doc_id: str) -> Document: @pytest.fixture -def setup_mock_redis() -> None: +def setup_mock_redis(): # get ext_redis.redis_client.get = MagicMock(return_value=None) @@ -48,7 +48,7 @@ class AbstractVectorTest: self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] - def create_vector(self) -> None: + def create_vector(self): self.vector.create( texts=[get_example_document(doc_id=self.example_doc_id)], embeddings=[self.example_embedding], diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 30414811e..bdd2f5afd 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -12,7 +12,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedCodeExecutor: @classmethod - def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict: + def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict): # invoke directly match language: case CodeLanguage.PYTHON3: diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index eb85d6118..7c6e52899 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -74,7 +74,7 @@ def init_code_node(code_config: dict): @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code(setup_code_executor_mock): code = """ - def main(args1: int, args2: int) -> dict: + def main(args1: int, args2: int): return { "result": args1 + args2, } @@ -120,7 +120,7 @@ def test_execute_code(setup_code_executor_mock): @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_code_output_validator(setup_code_executor_mock): code = """ - def main(args1: int, args2: int) -> dict: + def main(args1: int, args2: int): return { "result": args1 + args2, } @@ -163,7 +163,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): def test_execute_code_output_validator_depth(): code = """ - def main(args1: int, args2: int) -> dict: + def main(args1: int, args2: int): return { "result": { "result": args1 + args2, @@ -281,7 +281,7 @@ def test_execute_code_output_validator_depth(): def test_execute_code_output_object_list(): code = """ - def main(args1: int, args2: int) -> dict: + def main(args1: int, args2: int): return { "result": { "result": args1 + args2, @@ -356,7 +356,7 @@ def test_execute_code_output_object_list(): def test_execute_code_scientific_notation(): code = """ - def main() -> dict: + def main(): return { "result": -8.0E-5 } diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 66ddc0ba4..f28437f6c 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -49,7 +49,7 @@ class DifyTestContainers: self._containers_started = False logger.info("DifyTestContainers initialized - ready to manage test containers") - def start_containers_with_env(self) -> None: + def start_containers_with_env(self): """ Start all required containers for integration testing. @@ -230,7 +230,7 @@ class DifyTestContainers: self._containers_started = True logger.info("All test containers started successfully") - def stop_containers(self) -> None: + def stop_containers(self): """ Stop and clean up all test containers. diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index b88a57bfd..5895f63f9 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -23,7 +23,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: storage_key="storage_key_123", ) - def create_file_dict(self, file_id: str = "test_file_dict") -> dict: + def create_file_dict(self, file_id: str = "test_file_dict"): """Create a file dictionary with correct dify_model_identity""" return { "dify_model_identity": FILE_MODEL_IDENTITY, diff --git a/api/tests/unit_tests/core/mcp/client/test_session.py b/api/tests/unit_tests/core/mcp/client/test_session.py index c84169bf1..08d5b7d21 100644 --- a/api/tests/unit_tests/core/mcp/client/test_session.py +++ b/api/tests/unit_tests/core/mcp/client/test_session.py @@ -83,7 +83,7 @@ def test_client_session_initialize(): # Create message handler def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: + ): if isinstance(message, Exception): raise message diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index ed4e42425..0bf4fa7ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -777,7 +777,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): }, { "data": { - "code": '\ndef main(arg1: str, arg2: str) -> dict:\n return {\n "result": arg1 + arg2,\n }\n', # noqa: E501 + "code": '\ndef main(arg1: str, arg2: str):\n return {\n "result": arg1 + arg2,\n }\n', "code_language": "python3", "desc": "", "outputs": {"result": {"children": None, "type": "string"}}, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 3f8342883..d045ac5e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -233,7 +233,7 @@ FAIL_BRANCH_EDGES = [ def test_code_default_value_continue_on_error(): error_code = """ - def main() -> dict: + def main(): return { "result": 1 / 0, } @@ -259,7 +259,7 @@ def test_code_default_value_continue_on_error(): def test_code_fail_branch_continue_on_error(): error_code = """ - def main() -> dict: + def main(): return { "result": 1 / 0, } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 36a6fbb53..dc0524f43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -276,7 +276,7 @@ def test_array_file_contains_file_name(): assert result.outputs["result"] is True -def _get_test_conditions() -> list: +def _get_test_conditions(): conditions = [ # Test boolean "is" operator {"comparison_operator": "is", "variable_selector": ["start", "bool_true"], "value": "true"}, diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index c0330b944..0be85abfa 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -379,7 +379,7 @@ class TestVariablePoolSerialization: self._assert_pools_equal(reconstructed_dict, reconstructed_json) # TODO: assert the data for file object... - def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None: + def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool): """Assert that two VariablePools contain equivalent data""" # Compare system variables diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index aeb30438e..b80c711ca 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -27,7 +27,7 @@ from services.feature_service import BrandingModel class MockEmailRenderer: """Mock implementation of EmailRenderer protocol""" - def __init__(self) -> None: + def __init__(self): self.rendered_templates: list[tuple[str, dict[str, Any]]] = [] def render_template(self, template_path: str, **context: Any) -> str: @@ -39,7 +39,7 @@ class MockEmailRenderer: class MockBrandingService: """Mock implementation of BrandingService protocol""" - def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None: + def __init__(self, enabled: bool = False, application_title: str = "Dify"): self.enabled = enabled self.application_title = application_title @@ -54,10 +54,10 @@ class MockBrandingService: class MockEmailSender: """Mock implementation of EmailSender protocol""" - def __init__(self) -> None: + def __init__(self): self.sent_emails: list[dict[str, str]] = [] - def send_email(self, to: str, subject: str, html_content: str) -> None: + def send_email(self, to: str, subject: str, html_content: str): """Mock send_email that records sent emails""" self.sent_emails.append( { @@ -134,7 +134,7 @@ class TestEmailI18nService: email_service: EmailI18nService, mock_renderer: MockEmailRenderer, mock_sender: MockEmailSender, - ) -> None: + ): """Test sending email with English language""" email_service.send_email( email_type=EmailType.RESET_PASSWORD, @@ -162,7 +162,7 @@ class TestEmailI18nService: self, email_service: EmailI18nService, mock_sender: MockEmailSender, - ) -> None: + ): """Test sending email with Chinese language""" email_service.send_email( email_type=EmailType.RESET_PASSWORD, @@ -181,7 +181,7 @@ class TestEmailI18nService: email_config: EmailI18nConfig, mock_renderer: MockEmailRenderer, mock_sender: MockEmailSender, - ) -> None: + ): """Test sending email with branding enabled""" # Create branding service with branding enabled branding_service = MockBrandingService(enabled=True, application_title="MyApp") @@ -215,7 +215,7 @@ class TestEmailI18nService: self, email_service: EmailI18nService, mock_sender: MockEmailSender, - ) -> None: + ): """Test language fallback to English when requested language not available""" # Request invite member in Chinese (not configured) email_service.send_email( @@ -233,7 +233,7 @@ class TestEmailI18nService: self, email_service: EmailI18nService, mock_sender: MockEmailSender, - ) -> None: + ): """Test unknown language code falls back to English""" email_service.send_email( email_type=EmailType.RESET_PASSWORD, @@ -252,7 +252,7 @@ class TestEmailI18nService: mock_renderer: MockEmailRenderer, mock_sender: MockEmailSender, mock_branding_service: MockBrandingService, - ) -> None: + ): """Test sending change email for old email verification""" # Add change email templates to config email_config.templates[EmailType.CHANGE_EMAIL_OLD] = { @@ -290,7 +290,7 @@ class TestEmailI18nService: mock_renderer: MockEmailRenderer, mock_sender: MockEmailSender, mock_branding_service: MockBrandingService, - ) -> None: + ): """Test sending change email for new email verification""" # Add change email templates to config email_config.templates[EmailType.CHANGE_EMAIL_NEW] = { @@ -325,7 +325,7 @@ class TestEmailI18nService: def test_send_change_email_invalid_phase( self, email_service: EmailI18nService, - ) -> None: + ): """Test sending change email with invalid phase raises error""" with pytest.raises(ValueError, match="Invalid phase: invalid_phase"): email_service.send_change_email( @@ -339,7 +339,7 @@ class TestEmailI18nService: self, email_service: EmailI18nService, mock_sender: MockEmailSender, - ) -> None: + ): """Test sending raw email to single recipient""" email_service.send_raw_email( to="test@example.com", @@ -357,7 +357,7 @@ class TestEmailI18nService: self, email_service: EmailI18nService, mock_sender: MockEmailSender, - ) -> None: + ): """Test sending raw email to multiple recipients""" recipients = ["user1@example.com", "user2@example.com", "user3@example.com"] @@ -378,7 +378,7 @@ class TestEmailI18nService: def test_get_template_missing_email_type( self, email_config: EmailI18nConfig, - ) -> None: + ): """Test getting template for missing email type raises error""" with pytest.raises(ValueError, match="No templates configured for email type"): email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US) @@ -386,7 +386,7 @@ class TestEmailI18nService: def test_get_template_missing_language_and_english( self, email_config: EmailI18nConfig, - ) -> None: + ): """Test error when neither requested language nor English fallback exists""" # Add template without English fallback email_config.templates[EmailType.EMAIL_CODE_LOGIN] = { @@ -407,7 +407,7 @@ class TestEmailI18nService: mock_renderer: MockEmailRenderer, mock_sender: MockEmailSender, mock_branding_service: MockBrandingService, - ) -> None: + ): """Test subject templating with custom variables""" # Add template with variable in subject email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = { @@ -437,7 +437,7 @@ class TestEmailI18nService: sent_email = mock_sender.sent_emails[0] assert sent_email["subject"] == "You are now the owner of My Workspace" - def test_email_language_from_language_code(self) -> None: + def test_email_language_from_language_code(self): """Test EmailLanguage.from_language_code method""" assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US @@ -448,7 +448,7 @@ class TestEmailI18nService: class TestEmailI18nIntegration: """Integration tests for email i18n components""" - def test_create_default_email_config(self) -> None: + def test_create_default_email_config(self): """Test creating default email configuration""" config = create_default_email_config() @@ -476,7 +476,7 @@ class TestEmailI18nIntegration: assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD] assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER] - def test_get_email_i18n_service(self) -> None: + def test_get_email_i18n_service(self): """Test getting global email i18n service instance""" service1 = get_email_i18n_service() service2 = get_email_i18n_service() @@ -484,7 +484,7 @@ class TestEmailI18nIntegration: # Should return the same instance assert service1 is service2 - def test_flask_email_renderer(self) -> None: + def test_flask_email_renderer(self): """Test FlaskEmailRenderer implementation""" renderer = FlaskEmailRenderer() @@ -494,7 +494,7 @@ class TestEmailI18nIntegration: with pytest.raises(TemplateNotFound): renderer.render_template("test.html", foo="bar") - def test_flask_mail_sender_not_initialized(self) -> None: + def test_flask_mail_sender_not_initialized(self): """Test FlaskMailSender when mail is not initialized""" sender = FlaskMailSender() @@ -514,7 +514,7 @@ class TestEmailI18nIntegration: # Restore original mail libs.email_i18n.mail = original_mail - def test_flask_mail_sender_initialized(self) -> None: + def test_flask_mail_sender_initialized(self): """Test FlaskMailSender when mail is initialized""" sender = FlaskMailSender() diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index 2dc51252f..6a448d4f1 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -4,7 +4,7 @@ from Crypto.PublicKey import RSA from libs import gmpy2_pkcs10aep_cipher -def test_gmpy2_pkcs10aep_cipher() -> None: +def test_gmpy2_pkcs10aep_cipher(): rsa_key_pair = pyrsa.newkeys(2048) public_key = rsa_key_pair[0].save_pkcs1() private_key = rsa_key_pair[1].save_pkcs1() diff --git a/api/tests/unit_tests/models/test_account.py b/api/tests/unit_tests/models/test_account.py index 026912ffb..f555fc58d 100644 --- a/api/tests/unit_tests/models/test_account.py +++ b/api/tests/unit_tests/models/test_account.py @@ -1,7 +1,7 @@ from models.account import TenantAccountRole -def test_account_is_privileged_role() -> None: +def test_account_is_privileged_role(): assert TenantAccountRole.ADMIN == "admin" assert TenantAccountRole.OWNER == "owner" assert TenantAccountRole.EDITOR == "editor"