chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)

This commit is contained in:
Bowen Liang
2024-02-09 15:21:33 +08:00
committed by GitHub
parent 589099a005
commit 063191889d
246 changed files with 912 additions and 937 deletions

View File

@@ -3,7 +3,7 @@ import json
import logging
import mimetypes
from os import listdir, path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Union
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.model_runtime.entities.message_entities import PromptMessage
@@ -35,10 +35,10 @@ class ToolManager:
provider: str,
tool_id: str,
tool_name: str,
tool_parameters: Dict[str, Any],
credentials: Dict[str, Any],
prompt_messages: List[PromptMessage],
) -> List[ToolInvokeMessage]:
tool_parameters: dict[str, Any],
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
) -> list[ToolInvokeMessage]:
"""
invoke the assistant
@@ -200,7 +200,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_builtin_provider_icon(provider: str) -> Tuple[str, str]:
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
"""
get the absolute path of the icon of the builtin provider
@@ -223,14 +223,14 @@ class ToolManager:
return absolute_path, mime_type
@staticmethod
def list_builtin_providers() -> List[BuiltinToolProviderController]:
def list_builtin_providers() -> list[BuiltinToolProviderController]:
global _builtin_providers
# use cache first
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
builtin_providers: List[BuiltinToolProviderController] = []
builtin_providers: list[BuiltinToolProviderController] = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
@@ -289,8 +289,8 @@ class ToolManager:
def user_list_providers(
user_id: str,
tenant_id: str,
) -> List[UserToolProvider]:
result_providers: Dict[str, UserToolProvider] = {}
) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
# get builtin providers
builtin_providers = ToolManager.list_builtin_providers()
# append builtin providers
@@ -325,7 +325,7 @@ class ToolManager:
result_providers[provider.identity.name].allow_delete = False
# get db builtin providers
db_builtin_providers: List[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
for db_builtin_provider in db_builtin_providers:
@@ -346,7 +346,7 @@ class ToolManager:
result_providers[provider_name].team_credentials = masked_credentials
# get db api providers
db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
for db_api_provider in db_api_providers:
@@ -394,7 +394,7 @@ class ToolManager:
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod
def get_api_provider_controller(tenant_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
"""
get the api provider