refactor: use dify_config to replace legacy usage of flask app's config (#9089)

This commit is contained in:
Bowen Liang
2024-10-22 11:01:32 +08:00
committed by GitHub
parent 8f670f31b8
commit 4d9160ca9f
27 changed files with 221 additions and 207 deletions

View File

@@ -3,6 +3,8 @@ from datetime import timedelta
from celery import Celery, Task
from flask import Flask
from configs import dify_config
def init_app(app: Flask) -> Celery:
class FlaskTask(Task):
@@ -12,19 +14,19 @@ def init_app(app: Flask) -> Celery:
broker_transport_options = {}
if app.config.get("CELERY_USE_SENTINEL"):
if dify_config.CELERY_USE_SENTINEL:
broker_transport_options = {
"master_name": app.config.get("CELERY_SENTINEL_MASTER_NAME"),
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": app.config.get("CELERY_SENTINEL_SOCKET_TIMEOUT", 0.1),
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
},
}
celery_app = Celery(
app.name,
task_cls=FlaskTask,
broker=app.config.get("CELERY_BROKER_URL"),
backend=app.config.get("CELERY_BACKEND"),
broker=dify_config.CELERY_BROKER_URL,
backend=dify_config.CELERY_BACKEND,
task_ignore_result=True,
)
@@ -37,12 +39,12 @@ def init_app(app: Flask) -> Celery:
}
celery_app.conf.update(
result_backend=app.config.get("CELERY_RESULT_BACKEND"),
result_backend=dify_config.CELERY_RESULT_BACKEND,
broker_transport_options=broker_transport_options,
broker_connection_retry_on_startup=True,
)
if app.config.get("BROKER_USE_SSL"):
if dify_config.BROKER_USE_SSL:
celery_app.conf.update(
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
)
@@ -54,7 +56,7 @@ def init_app(app: Flask) -> Celery:
"schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task",
]
day = app.config.get("CELERY_BEAT_SCHEDULER_TIME")
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
"clean_embedding_cache_task": {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",

View File

@@ -1,8 +1,10 @@
from flask import Flask
from configs import dify_config
def init_app(app: Flask):
if app.config.get("API_COMPRESSION_ENABLED"):
if dify_config.API_COMPRESSION_ENABLED:
from flask_compress import Compress
app.config["COMPRESS_MIMETYPES"] = [

View File

@@ -5,10 +5,12 @@ from logging.handlers import RotatingFileHandler
from flask import Flask
from configs import dify_config
def init_app(app: Flask):
log_handlers = None
log_file = app.config.get("LOG_FILE")
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
@@ -22,13 +24,13 @@ def init_app(app: Flask):
]
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
log_tz = app.config.get("LOG_TZ")
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime

View File

@@ -4,6 +4,8 @@ from typing import Optional
import resend
from flask import Flask
from configs import dify_config
class Mail:
def __init__(self):
@@ -14,41 +16,44 @@ class Mail:
return self._client is not None
def init_app(self, app: Flask):
if app.config.get("MAIL_TYPE"):
if app.config.get("MAIL_DEFAULT_SEND_FROM"):
self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
mail_type = dify_config.MAIL_TYPE
if not mail_type:
logging.warning("MAIL_TYPE is not set")
return
if app.config.get("MAIL_TYPE") == "resend":
api_key = app.config.get("RESEND_API_KEY")
if dify_config.MAIL_DEFAULT_SEND_FROM:
self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM
match mail_type:
case "resend":
api_key = dify_config.RESEND_API_KEY
if not api_key:
raise ValueError("RESEND_API_KEY is not set")
api_url = app.config.get("RESEND_API_URL")
api_url = dify_config.RESEND_API_URL
if api_url:
resend.api_url = api_url
resend.api_key = api_key
self._client = resend.Emails
elif app.config.get("MAIL_TYPE") == "smtp":
case "smtp":
from libs.smtp import SMTPClient
if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT:
raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS:
raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
self._client = SMTPClient(
server=app.config.get("SMTP_SERVER"),
port=app.config.get("SMTP_PORT"),
username=app.config.get("SMTP_USERNAME"),
password=app.config.get("SMTP_PASSWORD"),
_from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
use_tls=app.config.get("SMTP_USE_TLS"),
opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
server=dify_config.SMTP_SERVER,
port=dify_config.SMTP_PORT,
username=dify_config.SMTP_USERNAME,
password=dify_config.SMTP_PASSWORD,
_from=dify_config.MAIL_DEFAULT_SEND_FROM,
use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
)
else:
raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
else:
logging.warning("MAIL_TYPE is not set")
case _:
raise ValueError("Unsupported mail type {}".format(mail_type))
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client:

View File

@@ -2,6 +2,8 @@ import redis
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel
from configs import dify_config
class RedisClientWrapper(redis.Redis):
"""
@@ -43,37 +45,37 @@ redis_client = RedisClientWrapper()
def init_app(app):
global redis_client
connection_class = Connection
if app.config.get("REDIS_USE_SSL"):
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection
redis_params = {
"username": app.config.get("REDIS_USERNAME"),
"password": app.config.get("REDIS_PASSWORD"),
"db": app.config.get("REDIS_DB"),
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
}
if app.config.get("REDIS_USE_SENTINEL"):
if dify_config.REDIS_USE_SENTINEL:
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in app.config.get("REDIS_SENTINELS").split(",")
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]
sentinel = Sentinel(
sentinel_hosts,
sentinel_kwargs={
"socket_timeout": app.config.get("REDIS_SENTINEL_SOCKET_TIMEOUT", 0.1),
"username": app.config.get("REDIS_SENTINEL_USERNAME"),
"password": app.config.get("REDIS_SENTINEL_PASSWORD"),
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
"username": dify_config.REDIS_SENTINEL_USERNAME,
"password": dify_config.REDIS_SENTINEL_PASSWORD,
},
)
master = sentinel.master_for(app.config.get("REDIS_SENTINEL_SERVICE_NAME"), **redis_params)
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master)
else:
redis_params.update(
{
"host": app.config.get("REDIS_HOST"),
"port": app.config.get("REDIS_PORT"),
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)

View File

@@ -5,6 +5,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException
from configs import dify_config
from core.model_runtime.errors.invoke import InvokeRateLimitError
@@ -18,9 +19,9 @@ def before_send(event, hint):
def init_app(app):
if app.config.get("SENTRY_DSN"):
if dify_config.SENTRY_DSN:
sentry_sdk.init(
dsn=app.config.get("SENTRY_DSN"),
dsn=dify_config.SENTRY_DSN,
integrations=[FlaskIntegration(), CeleryIntegration()],
ignore_errors=[
HTTPException,
@@ -29,9 +30,9 @@ def init_app(app):
InvokeRateLimitError,
parse_error.defaultErrorResponse,
],
traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
environment=app.config.get("DEPLOY_ENV"),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
environment=dify_config.DEPLOY_ENV,
release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
before_send=before_send,
)

View File

@@ -15,7 +15,7 @@ class Storage:
def init_app(self, app: Flask):
storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
self.storage_runner = storage_factory(app=app)
self.storage_runner = storage_factory()
@staticmethod
def get_storage_factory(storage_type: str) -> type[BaseStorage]:

View File

@@ -1,29 +1,27 @@
from collections.abc import Generator
import oss2 as aliyun_s3
from flask import Flask
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class AliyunOssStorage(BaseStorage):
"""Implementation for Aliyun OSS storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME")
self.folder = app.config.get("ALIYUN_OSS_PATH")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME
self.folder = dify_config.ALIYUN_OSS_PATH
oss_auth_method = aliyun_s3.Auth
region = None
if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4":
if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4":
oss_auth_method = aliyun_s3.AuthV4
region = app_config.get("ALIYUN_OSS_REGION")
oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY"))
region = dify_config.ALIYUN_OSS_REGION
oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY)
self.client = aliyun_s3.Bucket(
oss_auth,
app_config.get("ALIYUN_OSS_ENDPOINT"),
dify_config.ALIYUN_OSS_ENDPOINT,
self.bucket_name,
connect_timeout=30,
region=region,

View File

@@ -2,8 +2,8 @@ from collections.abc import Generator
from datetime import datetime, timedelta, timezone
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from flask import Flask
from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
@@ -11,13 +11,12 @@ from extensions.storage.base_storage import BaseStorage
class AzureBlobStorage(BaseStorage):
"""Implementation for Azure Blob storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME")
self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL")
self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME")
self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
def save(self, filename, data):
client = self._sync_client()

View File

@@ -5,24 +5,23 @@ from collections.abc import Generator
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
from baidubce.services.bos.bos_client import BosClient
from flask import Flask
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class BaiduObsStorage(BaseStorage):
"""Implementation for Baidu OBS storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("BAIDU_OBS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME
client_config = BceClientConfiguration(
credentials=BceCredentials(
access_key_id=app_config.get("BAIDU_OBS_ACCESS_KEY"),
secret_access_key=app_config.get("BAIDU_OBS_SECRET_KEY"),
access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY,
secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY,
),
endpoint=app_config.get("BAIDU_OBS_ENDPOINT"),
endpoint=dify_config.BAIDU_OBS_ENDPOINT,
)
self.client = BosClient(config=client_config)

View File

@@ -3,16 +3,12 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from flask import Flask
class BaseStorage(ABC):
"""Interface for file storage."""
app = None
def __init__(self, app: Flask):
self.app = app
def __init__(self): # noqa: B027
pass
@abstractmethod
def save(self, filename, data):

View File

@@ -3,20 +3,20 @@ import io
import json
from collections.abc import Generator
from flask import Flask
from google.cloud import storage as google_cloud_storage
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class GoogleCloudStorage(BaseStorage):
"""Implementation for Google Cloud storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME")
service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME
service_account_json_str = dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64
# if service_account_json_str is empty, use Application Default Credentials
if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")

View File

@@ -1,22 +1,22 @@
from collections.abc import Generator
from flask import Flask
from obs import ObsClient
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class HuaweiObsStorage(BaseStorage):
"""Implementation for Huawei OBS storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("HUAWEI_OBS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME
self.client = ObsClient(
access_key_id=app_config.get("HUAWEI_OBS_ACCESS_KEY"),
secret_access_key=app_config.get("HUAWEI_OBS_SECRET_KEY"),
server=app_config.get("HUAWEI_OBS_SERVER"),
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
server=dify_config.HUAWEI_OBS_SERVER,
)
def save(self, filename, data):

View File

@@ -3,19 +3,20 @@ import shutil
from collections.abc import Generator
from pathlib import Path
from flask import Flask
from flask import current_app
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class LocalFsStorage(BaseStorage):
"""Implementation for local filesystem storage."""
def __init__(self, app: Flask):
super().__init__(app)
folder = self.app.config.get("STORAGE_LOCAL_PATH")
def __init__(self):
super().__init__()
folder = dify_config.STORAGE_LOCAL_PATH
if not os.path.isabs(folder):
folder = os.path.join(app.root_path, folder)
folder = os.path.join(current_app.root_path, folder)
self.folder = folder
def save(self, filename, data):

View File

@@ -2,24 +2,24 @@ from collections.abc import Generator
import boto3
from botocore.exceptions import ClientError
from flask import Flask
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class OracleOCIStorage(BaseStorage):
"""Implementation for Oracle OCI storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("OCI_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.OCI_BUCKET_NAME
self.client = boto3.client(
"s3",
aws_secret_access_key=app_config.get("OCI_SECRET_KEY"),
aws_access_key_id=app_config.get("OCI_ACCESS_KEY"),
endpoint_url=app_config.get("OCI_ENDPOINT"),
region_name=app_config.get("OCI_REGION"),
aws_secret_access_key=dify_config.OCI_SECRET_KEY,
aws_access_key_id=dify_config.OCI_ACCESS_KEY,
endpoint_url=dify_config.OCI_ENDPOINT,
region_name=dify_config.OCI_REGION,
)
def save(self, filename, data):

View File

@@ -1,23 +1,23 @@
from collections.abc import Generator
from flask import Flask
from qcloud_cos import CosConfig, CosS3Client
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class TencentCosStorage(BaseStorage):
"""Implementation for Tencent Cloud COS storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
config = CosConfig(
Region=app_config.get("TENCENT_COS_REGION"),
SecretId=app_config.get("TENCENT_COS_SECRET_ID"),
SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"),
Scheme=app_config.get("TENCENT_COS_SCHEME"),
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
self.client = CosS3Client(config)

View File

@@ -1,23 +1,22 @@
from collections.abc import Generator
import tos
from flask import Flask
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
class VolcengineTosStorage(BaseStorage):
"""Implementation for Volcengine TOS storage."""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get("VOLCENGINE_TOS_BUCKET_NAME")
def __init__(self):
super().__init__()
self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME
self.client = tos.TosClientV2(
ak=app_config.get("VOLCENGINE_TOS_ACCESS_KEY"),
sk=app_config.get("VOLCENGINE_TOS_SECRET_KEY"),
endpoint=app_config.get("VOLCENGINE_TOS_ENDPOINT"),
region=app_config.get("VOLCENGINE_TOS_REGION"),
ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY,
sk=dify_config.VOLCENGINE_TOS_SECRET_KEY,
endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT,
region=dify_config.VOLCENGINE_TOS_REGION,
)
def save(self, filename, data):