feat: mypy for all type check (#10921)
This commit is contained in:
0
api/extensions/__init__.py
Normal file
0
api/extensions/__init__.py
Normal file
@@ -54,12 +54,14 @@ def init_app(app: DifyApp):
|
||||
from extensions.ext_database import db
|
||||
|
||||
engine = db.engine
|
||||
# TODO: Fix the type error
|
||||
# FIXME maybe its sqlalchemy issue
|
||||
return {
|
||||
"pid": os.getpid(),
|
||||
"pool_size": engine.pool.size(),
|
||||
"checked_in_connections": engine.pool.checkedin(),
|
||||
"checked_out_connections": engine.pool.checkedout(),
|
||||
"overflow_connections": engine.pool.overflow(),
|
||||
"connection_timeout": engine.pool.timeout(),
|
||||
"recycle_time": db.engine.pool._recycle,
|
||||
"pool_size": engine.pool.size(), # type: ignore
|
||||
"checked_in_connections": engine.pool.checkedin(), # type: ignore
|
||||
"checked_out_connections": engine.pool.checkedout(), # type: ignore
|
||||
"overflow_connections": engine.pool.overflow(), # type: ignore
|
||||
"connection_timeout": engine.pool.timeout(), # type: ignore
|
||||
"recycle_time": db.engine.pool._recycle, # type: ignore
|
||||
}
|
||||
|
@@ -1,8 +1,8 @@
|
||||
from datetime import timedelta
|
||||
|
||||
import pytz
|
||||
from celery import Celery, Task
|
||||
from celery.schedules import crontab
|
||||
from celery import Celery, Task # type: ignore
|
||||
from celery.schedules import crontab # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
@@ -47,7 +47,7 @@ def init_app(app: DifyApp) -> Celery:
|
||||
worker_log_format=dify_config.LOG_FORMAT,
|
||||
worker_task_log_format=dify_config.LOG_FORMAT,
|
||||
worker_hijack_root_logger=False,
|
||||
timezone=pytz.timezone(dify_config.LOG_TZ),
|
||||
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
|
||||
)
|
||||
|
||||
if dify_config.BROKER_USE_SSL:
|
||||
|
@@ -7,7 +7,7 @@ def is_enabled() -> bool:
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
from flask_compress import Compress
|
||||
from flask_compress import Compress # type: ignore
|
||||
|
||||
compress = Compress()
|
||||
compress.init_app(app)
|
||||
|
@@ -11,7 +11,7 @@ from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
log_handlers = []
|
||||
log_handlers: list[logging.Handler] = []
|
||||
log_file = dify_config.LOG_FILE
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
@@ -49,7 +49,8 @@ def init_app(app: DifyApp):
|
||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
handler.formatter.converter = time_converter
|
||||
if handler.formatter:
|
||||
handler.formatter.converter = time_converter
|
||||
|
||||
|
||||
def get_request_id():
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
import flask_login
|
||||
import flask_login # type: ignore
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
@@ -26,7 +26,7 @@ class Mail:
|
||||
|
||||
match mail_type:
|
||||
case "resend":
|
||||
import resend
|
||||
import resend # type: ignore
|
||||
|
||||
api_key = dify_config.RESEND_API_KEY
|
||||
if not api_key:
|
||||
@@ -48,9 +48,9 @@ class Mail:
|
||||
self._client = SMTPClient(
|
||||
server=dify_config.SMTP_SERVER,
|
||||
port=dify_config.SMTP_PORT,
|
||||
username=dify_config.SMTP_USERNAME,
|
||||
password=dify_config.SMTP_PASSWORD,
|
||||
_from=dify_config.MAIL_DEFAULT_SEND_FROM,
|
||||
username=dify_config.SMTP_USERNAME or "",
|
||||
password=dify_config.SMTP_PASSWORD or "",
|
||||
_from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
|
||||
use_tls=dify_config.SMTP_USE_TLS,
|
||||
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
|
||||
)
|
||||
|
@@ -2,7 +2,7 @@ from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
import flask_migrate
|
||||
import flask_migrate # type: ignore
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
|
||||
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app)
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore
|
||||
|
@@ -6,7 +6,7 @@ def init_app(app: DifyApp):
|
||||
if dify_config.SENTRY_DSN:
|
||||
import openai
|
||||
import sentry_sdk
|
||||
from langfuse import parse_error
|
||||
from langfuse import parse_error # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Union
|
||||
from typing import Literal, Union, overload
|
||||
|
||||
from flask import Flask
|
||||
|
||||
@@ -79,6 +79,12 @@ class Storage:
|
||||
logger.exception(f"Failed to save file {filename}")
|
||||
raise e
|
||||
|
||||
@overload
|
||||
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
|
||||
|
||||
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
|
||||
try:
|
||||
if stream:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import posixpath
|
||||
from collections.abc import Generator
|
||||
|
||||
import oss2 as aliyun_s3
|
||||
import oss2 as aliyun_s3 # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -33,7 +33,7 @@ class AliyunOssStorage(BaseStorage):
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
data = obj.read()
|
||||
data: bytes = obj.read()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
@@ -41,14 +41,14 @@ class AliyunOssStorage(BaseStorage):
|
||||
while chunk := obj.read(4096):
|
||||
yield chunk
|
||||
|
||||
def download(self, filename, target_filepath):
|
||||
def download(self, filename: str, target_filepath):
|
||||
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
|
||||
|
||||
def exists(self, filename):
|
||||
def exists(self, filename: str):
|
||||
return self.client.object_exists(self.__wrapper_folder_filename(filename))
|
||||
|
||||
def delete(self, filename):
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(self.__wrapper_folder_filename(filename))
|
||||
|
||||
def __wrapper_folder_filename(self, filename) -> str:
|
||||
def __wrapper_folder_filename(self, filename: str) -> str:
|
||||
return posixpath.join(self.folder, filename) if self.folder else filename
|
||||
|
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
from botocore.exceptions import ClientError # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -53,7 +53,7 @@ class AwsS3Storage(BaseStorage):
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
except ClientError as ex:
|
||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||
raise FileNotFoundError("File not found")
|
||||
|
@@ -27,7 +27,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
client = self._sync_client()
|
||||
blob = client.get_container_client(container=self.bucket_name)
|
||||
blob = blob.get_blob_client(blob=filename)
|
||||
data = blob.download_blob().readall()
|
||||
data: bytes = blob.download_blob().readall()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
@@ -63,11 +63,11 @@ class AzureBlobStorage(BaseStorage):
|
||||
sas_token = cache_result.decode("utf-8")
|
||||
else:
|
||||
sas_token = generate_account_sas(
|
||||
account_name=self.account_name,
|
||||
account_key=self.account_key,
|
||||
account_name=self.account_name or "",
|
||||
account_key=self.account_key or "",
|
||||
resource_types=ResourceTypes(service=True, container=True, object=True),
|
||||
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
|
||||
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
|
||||
)
|
||||
redis_client.set(cache_key, sas_token, ex=3000)
|
||||
return BlobServiceClient(account_url=self.account_url, credential=sas_token)
|
||||
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
|
||||
|
@@ -2,9 +2,9 @@ import base64
|
||||
import hashlib
|
||||
from collections.abc import Generator
|
||||
|
||||
from baidubce.auth.bce_credentials import BceCredentials
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||
from baidubce.services.bos.bos_client import BosClient
|
||||
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
|
||||
from baidubce.services.bos.bos_client import BosClient # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -36,7 +36,8 @@ class BaiduObsStorage(BaseStorage):
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
|
||||
return response.data.read()
|
||||
data: bytes = response.data.read()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
|
||||
|
@@ -3,7 +3,7 @@ import io
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from google.cloud import storage as google_cloud_storage
|
||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -35,7 +35,7 @@ class GoogleCloudStorage(BaseStorage):
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(filename)
|
||||
data = blob.download_as_bytes()
|
||||
data: bytes = blob.download_as_bytes()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from obs import ObsClient
|
||||
from obs import ObsClient # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -23,7 +23,7 @@ class HuaweiObsStorage(BaseStorage):
|
||||
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
|
||||
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
@@ -3,7 +3,7 @@ import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import opendal
|
||||
import opendal # type: ignore[import]
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
|
||||
if key.startswith(config_prefix):
|
||||
kwargs[key[len(config_prefix) :].lower()] = value
|
||||
|
||||
file_env_vars = dotenv_values(env_file_path)
|
||||
file_env_vars: dict = dotenv_values(env_file_path) or {}
|
||||
for key, value in file_env_vars.items():
|
||||
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
|
||||
kwargs[key[len(config_prefix) :].lower()] = value
|
||||
@@ -48,7 +48,7 @@ class OpenDALStorage(BaseStorage):
|
||||
if not self.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
|
||||
content = self.op.read(path=filename)
|
||||
content: bytes = self.op.read(path=filename)
|
||||
logger.debug(f"file {filename} loaded")
|
||||
return content
|
||||
|
||||
@@ -75,7 +75,7 @@ class OpenDALStorage(BaseStorage):
|
||||
# error handler here when opendal python-binding has a exists method, we should use it
|
||||
# more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
|
||||
try:
|
||||
res = self.op.stat(path=filename).mode.is_file()
|
||||
res: bool = self.op.stat(path=filename).mode.is_file()
|
||||
logger.debug(f"file {filename} checked")
|
||||
return res
|
||||
except Exception:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import boto3 # type: ignore
|
||||
from botocore.exceptions import ClientError # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -27,7 +27,7 @@ class OracleOCIStorage(BaseStorage):
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
except ClientError as ex:
|
||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||
raise FileNotFoundError("File not found")
|
||||
|
@@ -32,7 +32,7 @@ class SupabaseStorage(BaseStorage):
|
||||
self.client.storage.from_(self.bucket_name).upload(filename, data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
content = self.client.storage.from_(self.bucket_name).download(filename)
|
||||
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from qcloud_cos import CosConfig, CosS3Client # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -25,7 +25,7 @@ class TencentCosStorage(BaseStorage):
|
||||
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import tos
|
||||
import tos # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@@ -24,6 +24,8 @@ class VolcengineTosStorage(BaseStorage):
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError("Expected bytes, got {}".format(type(data).__name__))
|
||||
return data
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
|
Reference in New Issue
Block a user