fix: resolve typing errors in configs module (#25268)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-06 16:08:14 +08:00
committed by GitHub
parent e41e23481c
commit b05245eab0
11 changed files with 77 additions and 65 deletions

View File

@@ -300,8 +300,7 @@ class DatasetQueueMonitorConfig(BaseSettings):
class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig, CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
DatabaseConfig,
KeywordStoreConfig, KeywordStoreConfig,
RedisConfig, RedisConfig,
# configs of storage and storage providers # configs of storage and storage providers

View File

@@ -1,9 +1,10 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import Field
from pydantic_settings import BaseSettings
class ClickzettaConfig(BaseModel): class ClickzettaConfig(BaseSettings):
""" """
Clickzetta Lakehouse vector database configuration Clickzetta Lakehouse vector database configuration
""" """

View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field from pydantic import Field
from pydantic_settings import BaseSettings
class MatrixoneConfig(BaseModel): class MatrixoneConfig(BaseSettings):
"""Matrixone vector database configuration.""" """Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

View File

@@ -1,6 +1,6 @@
from pydantic import Field from pydantic import Field
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig from configs.packaging.pyproject import PyProjectTomlConfig
class PackagingInfo(PyProjectTomlConfig): class PackagingInfo(PyProjectTomlConfig):

View File

@@ -4,8 +4,9 @@ import logging
import os import os
import threading import threading
import time import time
from collections.abc import Mapping from collections.abc import Callable, Mapping
from pathlib import Path from pathlib import Path
from typing import Any
from .python_3x import http_request, makedirs_wrapper from .python_3x import http_request, makedirs_wrapper
from .utils import ( from .utils import (
@@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
class ApolloClient: class ApolloClient:
def __init__( def __init__(
self, self,
config_url, config_url: str,
app_id, app_id: str,
cluster="default", cluster: str = "default",
secret="", secret: str = "",
start_hot_update=True, start_hot_update: bool = True,
change_listener=None, change_listener: Callable[[str, str, str, Any], None] | None = None,
_notification_map=None, _notification_map: dict[str, int] | None = None,
): ):
# Core routing parameters # Core routing parameters
self.config_url = config_url self.config_url = config_url
@@ -47,17 +48,17 @@ class ApolloClient:
# Private control variables # Private control variables
self._cycle_time = 5 self._cycle_time = 5
self._stopping = False self._stopping = False
self._cache = {} self._cache: dict[str, dict[str, Any]] = {}
self._no_key = {} self._no_key: dict[str, str] = {}
self._hash = {} self._hash: dict[str, str] = {}
self._pull_timeout = 75 self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/" self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None self._long_poll_thread: threading.Thread | None = None
self._change_listener = change_listener # "add" "delete" "update" self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None: if _notification_map is None:
_notification_map = {"application": -1} _notification_map = {"application": -1}
self._notification_map = _notification_map self._notification_map = _notification_map
self.last_release_key = None self.last_release_key: str | None = None
# Private startup method # Private startup method
self._path_checker() self._path_checker()
if start_hot_update: if start_hot_update:
@@ -68,7 +69,7 @@ class ApolloClient:
heartbeat.daemon = True heartbeat.daemon = True
heartbeat.start() heartbeat.start()
def get_json_from_net(self, namespace="application"): def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format( url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip self.config_url, self.app_id, self.cluster, namespace, "", self.ip
) )
@@ -88,7 +89,7 @@ class ApolloClient:
logger.exception("an error occurred in get_json_from_net") logger.exception("an error occurred in get_json_from_net")
return None return None
def get_value(self, key, default_val=None, namespace="application"): def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
try: try:
# read memory configuration # read memory configuration
namespace_cache = self._cache.get(namespace) namespace_cache = self._cache.get(namespace)
@@ -104,6 +105,7 @@ class ApolloClient:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key) val = get_value_from_dict(namespace_data, key)
if val is not None: if val is not None:
if namespace_data is not None:
self._update_cache_and_file(namespace_data, namespace) self._update_cache_and_file(namespace_data, namespace)
return val return val
@@ -126,23 +128,23 @@ class ApolloClient:
# to ensure the real-time correctness of the function call. # to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice # If the user does not have the same default val twice
# and the default val is used here, there may be a problem. # and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key): def _set_local_cache_none(self, namespace: str, key: str) -> None:
no_key = no_key_cache_key(namespace, key) no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key self._no_key[no_key] = key
def _start_hot_update(self): def _start_hot_update(self) -> None:
self._long_poll_thread = threading.Thread(target=self._listener) self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit # When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched. # when the main thread is launched.
self._long_poll_thread.daemon = True self._long_poll_thread.daemon = True
self._long_poll_thread.start() self._long_poll_thread.start()
def stop(self): def stop(self) -> None:
self._stopping = True self._stopping = True
logger.info("Stopping listener...") logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out # Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv): def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
if self._change_listener is None: if self._change_listener is None:
return return
if old_kv is None: if old_kv is None:
@@ -168,12 +170,12 @@ class ApolloClient:
except BaseException as e: except BaseException as e:
logger.warning(str(e)) logger.warning(str(e))
def _path_checker(self): def _path_checker(self) -> None:
if not os.path.isdir(self._cache_file_path): if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path) makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache # update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"): def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
# update the local cache # update the local cache
self._cache[namespace] = namespace_data self._cache[namespace] = namespace_data
# update the file cache # update the file cache
@@ -187,7 +189,7 @@ class ApolloClient:
self._hash[namespace] = new_hash self._hash[namespace] = new_hash
# get the configuration from the local file # get the configuration from the local file
def _get_local_cache(self, namespace="application"): def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt") cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path): if os.path.isfile(cache_file_path):
with open(cache_file_path) as f: with open(cache_file_path) as f:
@@ -195,8 +197,8 @@ class ApolloClient:
return result return result
return {} return {}
def _long_poll(self): def _long_poll(self) -> None:
notifications = [] notifications: list[dict[str, Any]] = []
for key in self._cache: for key in self._cache:
namespace_data = self._cache[key] namespace_data = self._cache[key]
notification_id = -1 notification_id = -1
@@ -236,7 +238,7 @@ class ApolloClient:
except Exception as e: except Exception as e:
logger.warning(str(e)) logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False): def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
namespace_data = self.get_json_from_net(namespace) namespace_data = self.get_json_from_net(namespace)
if not namespace_data: if not namespace_data:
return return
@@ -248,7 +250,7 @@ class ApolloClient:
new_kv = namespace_data.get(CONFIGURATIONS) new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv) self._call_listener(namespace, old_kv, new_kv)
def _listener(self): def _listener(self) -> None:
logger.info("start long_poll") logger.info("start long_poll")
while not self._stopping: while not self._stopping:
self._long_poll() self._long_poll()
@@ -266,13 +268,13 @@ class ApolloClient:
headers["Timestamp"] = time_unix_now headers["Timestamp"] = time_unix_now
return headers return headers
def _heart_beat(self): def _heart_beat(self) -> None:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace): def _do_heart_beat(self, namespace: str) -> None:
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try: try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
@@ -292,7 +294,7 @@ class ApolloClient:
logger.exception("an error occurred in _do_heart_beat") logger.exception("an error occurred in _do_heart_beat")
return None return None
def get_all_dicts(self, namespace): def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
namespace_data = self._cache.get(namespace) namespace_data = self._cache.get(namespace)
if namespace_data is None: if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace) net_namespace_data = self.get_json_from_net(namespace)

View File

@@ -2,6 +2,8 @@ import logging
import os import os
import ssl import ssl
import urllib.request import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse from urllib import parse
from urllib.error import HTTPError from urllib.error import HTTPError
@@ -19,9 +21,9 @@ urllib.request.install_opener(opener)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}): def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
try: try:
request = urllib.request.Request(url, headers=headers) request = urllib.request.Request(url, headers=dict(headers))
res = urllib.request.urlopen(request, timeout=timeout) res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8") body = res.read().decode("utf-8")
return res.code, body return res.code, body
@@ -33,9 +35,9 @@ def http_request(url, timeout, headers={}):
raise e raise e
def url_encode(params): def url_encode(params: dict[str, Any]) -> str:
return parse.urlencode(params) return parse.urlencode(params)
def makedirs_wrapper(path): def makedirs_wrapper(path: str) -> None:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

View File

@@ -1,5 +1,6 @@
import hashlib import hashlib
import socket import socket
from typing import Any
from .python_3x import url_encode from .python_3x import url_encode
@@ -10,7 +11,7 @@ NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys # add timestamps uris and keys
def signature(timestamp, uri, secret): def signature(timestamp: str, uri: str, secret: str) -> str:
import base64 import base64
import hmac import hmac
@@ -19,16 +20,16 @@ def signature(timestamp, uri, secret):
return base64.b64encode(hmac_code).decode() return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params): def url_encode_wrapper(params: dict[str, Any]) -> str:
return url_encode(params) return url_encode(params)
def no_key_cache_key(namespace, key): def no_key_cache_key(namespace: str, key: str) -> str:
return f"{namespace}{len(namespace)}{key}" return f"{namespace}{len(namespace)}{key}"
# Returns whether the obtained value is obtained, and None if it does not # Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key): def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
if namespace_cache: if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS) kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None: if kv_data is None:
@@ -38,7 +39,7 @@ def get_value_from_dict(namespace_cache, key):
return None return None
def init_ip(): def init_ip() -> str:
ip = "" ip = ""
s = None s = None
try: try:

View File

@@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
from configs.remote_settings_sources.base import RemoteSettingsSource from configs.remote_settings_sources.base import RemoteSettingsSource
from .utils import _parse_config from .utils import parse_config
class NacosSettingsSource(RemoteSettingsSource): class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]): def __init__(self, configs: Mapping[str, Any]):
self.configs = configs self.configs = configs
self.remote_configs: dict[str, Any] = {} self.remote_configs: dict[str, str] = {}
self.async_init() self.async_init()
def async_init(self): def async_init(self) -> None:
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties") data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify") group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "") tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
@@ -33,18 +33,15 @@ class NacosSettingsSource(RemoteSettingsSource):
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise
def _parse_config(self, content: str): def _parse_config(self, content: str) -> dict[str, str]:
if not content: if not content:
return {} return {}
try: try:
return _parse_config(self, content) return parse_config(content)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}") raise RuntimeError(f"Failed to parse config: {e}")
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name) field_value = self.remote_configs.get(field_name)
if field_value is None: if field_value is None:
return None, field_name, False return None, field_name, False

View File

@@ -17,11 +17,17 @@ class NacosHttpClient:
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY") self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY") self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848") self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token = None self.token: str | None = None
self.token_ttl = 18000 self.token_ttl = 18000
self.token_expire_time: float = 0 self.token_expire_time: float = 0
def http_request(self, url, method="GET", headers=None, params=None): def http_request(
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
@@ -30,7 +36,7 @@ class NacosHttpClient:
except requests.RequestException as e: except requests.RequestException as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers, params, module="config"): def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"}) headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
if module == "login": if module == "login":
@@ -45,16 +51,17 @@ class NacosHttpClient:
headers["timeStamp"] = ts headers["timeStamp"] = ts
if self.username and self.password: if self.username and self.password:
self.get_access_token(force_refresh=False) self.get_access_token(force_refresh=False)
if self.token is not None:
params["accessToken"] = self.token params["accessToken"] = self.token
def __do_sign(self, sign_str, sk): def __do_sign(self, sign_str: str, sk: str) -> str:
return ( return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest()) base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode() .decode()
.strip() .strip()
) )
def get_sign_str(self, group, tenant, ts): def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
sign_str = "" sign_str = ""
if tenant: if tenant:
sign_str = tenant + "+" sign_str = tenant + "+"
@@ -63,7 +70,7 @@ class NacosHttpClient:
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str return sign_str
def get_access_token(self, force_refresh=False): def get_access_token(self, force_refresh: bool = False) -> str | None:
current_time = time.time() current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time: if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token return self.token
@@ -77,6 +84,7 @@ class NacosHttpClient:
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000) self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10 self.token_expire_time = current_time + self.token_ttl - 10
return self.token
except Exception: except Exception:
logger.exception("[get-access-token] exception occur") logger.exception("[get-access-token] exception occur")
raise raise

View File

@@ -1,4 +1,4 @@
def _parse_config(self, content: str) -> dict[str, str]: def parse_config(content: str) -> dict[str, str]:
config: dict[str, str] = {} config: dict[str, str] = {}
if not content: if not content:
return config return config

View File

@@ -1,5 +1,7 @@
{ {
"include": ["."], "include": [
"."
],
"exclude": [ "exclude": [
"tests/", "tests/",
"migrations/", "migrations/",
@@ -19,7 +21,6 @@
"events/", "events/",
"contexts/", "contexts/",
"constants/", "constants/",
"configs/",
"commands.py" "commands.py"
], ],
"typeCheckingMode": "strict", "typeCheckingMode": "strict",