feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

View File

@@ -54,12 +54,14 @@ def init_app(app: DifyApp):
from extensions.ext_database import db
engine = db.engine
# TODO: Fix the type error
# FIXME maybe its sqlalchemy issue
return {
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
"pool_size": engine.pool.size(), # type: ignore
"checked_in_connections": engine.pool.checkedin(), # type: ignore
"checked_out_connections": engine.pool.checkedout(), # type: ignore
"overflow_connections": engine.pool.overflow(), # type: ignore
"connection_timeout": engine.pool.timeout(), # type: ignore
"recycle_time": db.engine.pool._recycle, # type: ignore
}

View File

@@ -1,8 +1,8 @@
from datetime import timedelta
import pytz
from celery import Celery, Task
from celery.schedules import crontab
from celery import Celery, Task # type: ignore
from celery.schedules import crontab # type: ignore
from configs import dify_config
from dify_app import DifyApp
@@ -47,7 +47,7 @@ def init_app(app: DifyApp) -> Celery:
worker_log_format=dify_config.LOG_FORMAT,
worker_task_log_format=dify_config.LOG_FORMAT,
worker_hijack_root_logger=False,
timezone=pytz.timezone(dify_config.LOG_TZ),
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
)
if dify_config.BROKER_USE_SSL:

View File

@@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
from flask_compress import Compress
from flask_compress import Compress # type: ignore
compress = Compress()
compress.init_app(app)

View File

@@ -11,7 +11,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
log_handlers = []
log_handlers: list[logging.Handler] = []
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -49,7 +49,8 @@ def init_app(app: DifyApp):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
handler.formatter.converter = time_converter
if handler.formatter:
handler.formatter.converter = time_converter
def get_request_id():

View File

@@ -1,6 +1,6 @@
import json
import flask_login
import flask_login # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import Unauthorized

View File

@@ -26,7 +26,7 @@ class Mail:
match mail_type:
case "resend":
import resend
import resend # type: ignore
api_key = dify_config.RESEND_API_KEY
if not api_key:
@@ -48,9 +48,9 @@ class Mail:
self._client = SMTPClient(
server=dify_config.SMTP_SERVER,
port=dify_config.SMTP_PORT,
username=dify_config.SMTP_USERNAME,
password=dify_config.SMTP_PASSWORD,
_from=dify_config.MAIL_DEFAULT_SEND_FROM,
username=dify_config.SMTP_USERNAME or "",
password=dify_config.SMTP_PASSWORD or "",
_from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
)

View File

@@ -2,7 +2,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
import flask_migrate
import flask_migrate # type: ignore
from extensions.ext_database import db

View File

@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
app.wsgi_app = ProxyFix(app.wsgi_app)
app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore

View File

@@ -6,7 +6,7 @@ def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import openai
import sentry_sdk
from langfuse import parse_error
from langfuse import parse_error # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException

View File

@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator
from typing import Union
from typing import Literal, Union, overload
from flask import Flask
@@ -79,6 +79,12 @@ class Storage:
logger.exception(f"Failed to save file {filename}")
raise e
@overload
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
@overload
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
try:
if stream:

View File

@@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
import oss2 as aliyun_s3
import oss2 as aliyun_s3 # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -33,7 +33,7 @@ class AliyunOssStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read()
data: bytes = obj.read()
return data
def load_stream(self, filename: str) -> Generator:
@@ -41,14 +41,14 @@ class AliyunOssStorage(BaseStorage):
while chunk := obj.read(4096):
yield chunk
def download(self, filename, target_filepath):
def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
def exists(self, filename):
def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
def delete(self, filename):
def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))
def __wrapper_folder_filename(self, filename) -> str:
def __wrapper_folder_filename(self, filename: str) -> str:
return posixpath.join(self.folder, filename) if self.folder else filename

View File

@@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -53,7 +53,7 @@ class AwsS3Storage(BaseStorage):
def load_once(self, filename: str) -> bytes:
try:
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")

View File

@@ -27,7 +27,7 @@ class AzureBlobStorage(BaseStorage):
client = self._sync_client()
blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
data: bytes = blob.download_blob().readall()
return data
def load_stream(self, filename: str) -> Generator:
@@ -63,11 +63,11 @@ class AzureBlobStorage(BaseStorage):
sas_token = cache_result.decode("utf-8")
else:
sas_token = generate_account_sas(
account_name=self.account_name,
account_key=self.account_key,
account_name=self.account_name or "",
account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url, credential=sas_token)
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)

View File

@@ -2,9 +2,9 @@ import base64
import hashlib
from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
from baidubce.services.bos.bos_client import BosClient # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -36,7 +36,8 @@ class BaiduObsStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
return response.data.read()
data: bytes = response.data.read()
return data
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data

View File

@@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage
from google.cloud import storage as google_cloud_storage # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -35,7 +35,7 @@ class GoogleCloudStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
data = blob.download_as_bytes()
data: bytes = blob.download_as_bytes()
return data
def load_stream(self, filename: str) -> Generator:

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from obs import ObsClient
from obs import ObsClient # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -23,7 +23,7 @@ class HuaweiObsStorage(BaseStorage):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
def load_once(self, filename: str) -> bytes:
data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
def load_stream(self, filename: str) -> Generator:

View File

@@ -3,7 +3,7 @@ import os
from collections.abc import Generator
from pathlib import Path
import opendal
import opendal # type: ignore[import]
from dotenv import dotenv_values
from extensions.storage.base_storage import BaseStorage
@@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value
file_env_vars = dotenv_values(env_file_path)
file_env_vars: dict = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value
@@ -48,7 +48,7 @@ class OpenDALStorage(BaseStorage):
if not self.exists(filename):
raise FileNotFoundError("File not found")
content = self.op.read(path=filename)
content: bytes = self.op.read(path=filename)
logger.debug(f"file {filename} loaded")
return content
@@ -75,7 +75,7 @@ class OpenDALStorage(BaseStorage):
# error handler here when opendal python-binding has a exists method, we should use it
# more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
try:
res = self.op.stat(path=filename).mode.is_file()
res: bool = self.op.stat(path=filename).mode.is_file()
logger.debug(f"file {filename} checked")
return res
except Exception:

View File

@@ -1,7 +1,7 @@
from collections.abc import Generator
import boto3
from botocore.exceptions import ClientError
import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -27,7 +27,7 @@ class OracleOCIStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
try:
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")

View File

@@ -32,7 +32,7 @@ class SupabaseStorage(BaseStorage):
self.client.storage.from_(self.bucket_name).upload(filename, data)
def load_once(self, filename: str) -> bytes:
content = self.client.storage.from_(self.bucket_name).download(filename)
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
def load_stream(self, filename: str) -> Generator:

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from qcloud_cos import CosConfig, CosS3Client
from qcloud_cos import CosConfig, CosS3Client # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -25,7 +25,7 @@ class TencentCosStorage(BaseStorage):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
def load_once(self, filename: str) -> bytes:
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
def load_stream(self, filename: str) -> Generator:

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
import tos
import tos # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -24,6 +24,8 @@ class VolcengineTosStorage(BaseStorage):
def load_once(self, filename: str) -> bytes:
data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
if not isinstance(data, bytes):
raise TypeError("Expected bytes, got {}".format(type(data).__name__))
return data
def load_stream(self, filename: str) -> Generator: