Initial commit
This commit is contained in:
23
api/extensions/ext_celery.py
Normal file
23
api/extensions/ext_celery.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from celery import Task, Celery
|
||||
from flask import Flask
|
||||
|
||||
|
||||
def init_app(app: Flask) -> Celery:
|
||||
class FlaskTask(Task):
|
||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||
with app.app_context():
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
celery_app = Celery(
|
||||
app.name,
|
||||
task_cls=FlaskTask,
|
||||
broker=app.config["CELERY_BROKER_URL"],
|
||||
backend=app.config["CELERY_BACKEND"],
|
||||
task_ignore_result=True,
|
||||
)
|
||||
celery_app.conf.update(
|
||||
result_backend=app.config["CELERY_RESULT_BACKEND"],
|
||||
)
|
||||
celery_app.set_default()
|
||||
app.extensions["celery"] = celery_app
|
||||
return celery_app
|
7
api/extensions/ext_database.py
Normal file
7
api/extensions/ext_database.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
|
||||
db = SQLAlchemy()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
db.init_app(app)
|
7
api/extensions/ext_login.py
Normal file
7
api/extensions/ext_login.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import flask_login
|
||||
|
||||
login_manager = flask_login.LoginManager()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
login_manager.init_app(app)
|
5
api/extensions/ext_migrate.py
Normal file
5
api/extensions/ext_migrate.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import flask_migrate
|
||||
|
||||
|
||||
def init(app, db):
|
||||
flask_migrate.Migrate(app, db)
|
18
api/extensions/ext_redis.py
Normal file
18
api/extensions/ext_redis.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import redis
|
||||
|
||||
|
||||
redis_client = redis.Redis()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
redis_client.connection_pool = redis.ConnectionPool(**{
|
||||
'host': app.config.get('REDIS_HOST', 'localhost'),
|
||||
'port': app.config.get('REDIS_PORT', 6379),
|
||||
'password': app.config.get('REDIS_PASSWORD', None),
|
||||
'db': app.config.get('REDIS_DB', 0),
|
||||
'encoding': 'utf-8',
|
||||
'encoding_errors': 'strict',
|
||||
'decode_responses': False
|
||||
})
|
||||
|
||||
app.extensions['redis'] = redis_client
|
20
api/extensions/ext_sentry.py
Normal file
20
api/extensions/ext_sentry.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
|
||||
def init_app(app):
|
||||
if app.config.get('SENTRY_DSN'):
|
||||
sentry_sdk.init(
|
||||
dsn=app.config.get('SENTRY_DSN'),
|
||||
integrations=[
|
||||
FlaskIntegration(),
|
||||
CeleryIntegration()
|
||||
],
|
||||
ignore_errors=[HTTPException, ValueError],
|
||||
traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0),
|
||||
profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0),
|
||||
environment=app.config.get('DEPLOY_ENV'),
|
||||
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}"
|
||||
)
|
168
api/extensions/ext_session.py
Normal file
168
api/extensions/ext_session.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import redis
|
||||
from flask import request
|
||||
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
|
||||
from flask_session.sessions import total_seconds
|
||||
from itsdangerous import want_bytes
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
sess = Session()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
|
||||
app,
|
||||
db,
|
||||
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
|
||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
||||
app.config.get('SESSION_USE_SIGNER', False),
|
||||
app.config.get('SESSION_PERMANENT', True)
|
||||
)
|
||||
|
||||
session_type = app.config.get('SESSION_TYPE')
|
||||
if session_type == 'sqlalchemy':
|
||||
app.session_interface = sqlalchemy_session_interface
|
||||
elif session_type == 'redis':
|
||||
sess_redis_client = redis.Redis()
|
||||
sess_redis_client.connection_pool = redis.ConnectionPool(**{
|
||||
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
|
||||
'port': app.config.get('SESSION_REDIS_PORT', 6379),
|
||||
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
|
||||
'db': app.config.get('SESSION_REDIS_DB', 2),
|
||||
'encoding': 'utf-8',
|
||||
'encoding_errors': 'strict',
|
||||
'decode_responses': False
|
||||
})
|
||||
|
||||
app.extensions['session_redis'] = sess_redis_client
|
||||
|
||||
app.session_interface = CustomRedisSessionInterface(
|
||||
sess_redis_client,
|
||||
app.config.get('SESSION_KEY_PREFIX', 'session:'),
|
||||
app.config.get('SESSION_USE_SIGNER', False),
|
||||
app.config.get('SESSION_PERMANENT', True)
|
||||
)
|
||||
|
||||
|
||||
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
db,
|
||||
table,
|
||||
key_prefix,
|
||||
use_signer=False,
|
||||
permanent=True,
|
||||
sequence=None,
|
||||
autodelete=False,
|
||||
):
|
||||
if db is None:
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
|
||||
db = SQLAlchemy(app)
|
||||
self.db = db
|
||||
self.key_prefix = key_prefix
|
||||
self.use_signer = use_signer
|
||||
self.permanent = permanent
|
||||
self.autodelete = autodelete
|
||||
self.sequence = sequence
|
||||
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
|
||||
|
||||
class Session(self.db.Model):
|
||||
__tablename__ = table
|
||||
|
||||
if sequence:
|
||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
||||
self.db.Integer, self.db.Sequence(sequence), primary_key=True
|
||||
)
|
||||
else:
|
||||
id = self.db.Column( # noqa: A003, VNE003, A001
|
||||
self.db.Integer, primary_key=True
|
||||
)
|
||||
|
||||
session_id = self.db.Column(self.db.String(255), unique=True)
|
||||
data = self.db.Column(self.db.LargeBinary)
|
||||
expiry = self.db.Column(self.db.DateTime)
|
||||
|
||||
def __init__(self, session_id, data, expiry):
|
||||
self.session_id = session_id
|
||||
self.data = data
|
||||
self.expiry = expiry
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Session data {self.data}>"
|
||||
|
||||
self.sql_session_model = Session
|
||||
|
||||
def save_session(self, *args, **kwargs):
|
||||
if request.blueprint == 'service_api':
|
||||
return
|
||||
elif request.method == 'OPTIONS':
|
||||
return
|
||||
elif request.endpoint and request.endpoint == 'health':
|
||||
return
|
||||
return super().save_session(*args, **kwargs)
|
||||
|
||||
|
||||
class CustomRedisSessionInterface(RedisSessionInterface):
|
||||
|
||||
def save_session(self, app, session, response):
|
||||
if request.blueprint == 'service_api':
|
||||
return
|
||||
elif request.method == 'OPTIONS':
|
||||
return
|
||||
elif request.endpoint and request.endpoint == 'health':
|
||||
return
|
||||
|
||||
if not self.should_set_cookie(app, session):
|
||||
return
|
||||
domain = self.get_cookie_domain(app)
|
||||
path = self.get_cookie_path(app)
|
||||
if not session:
|
||||
if session.modified:
|
||||
self.redis.delete(self.key_prefix + session.sid)
|
||||
response.delete_cookie(
|
||||
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
|
||||
)
|
||||
return
|
||||
|
||||
# Modification case. There are upsides and downsides to
|
||||
# emitting a set-cookie header each request. The behavior
|
||||
# is controlled by the :meth:`should_set_cookie` method
|
||||
# which performs a quick check to figure out if the cookie
|
||||
# should be set or not. This is controlled by the
|
||||
# SESSION_REFRESH_EACH_REQUEST config flag as well as
|
||||
# the permanent flag on the session itself.
|
||||
# if not self.should_set_cookie(app, session):
|
||||
# return
|
||||
conditional_cookie_kwargs = {}
|
||||
httponly = self.get_cookie_httponly(app)
|
||||
secure = self.get_cookie_secure(app)
|
||||
if self.has_same_site_capability:
|
||||
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
|
||||
expires = self.get_expiration_time(app, session)
|
||||
|
||||
if session.permanent:
|
||||
value = self.serializer.dumps(dict(session))
|
||||
if value is not None:
|
||||
self.redis.setex(
|
||||
name=self.key_prefix + session.sid,
|
||||
value=value,
|
||||
time=total_seconds(app.permanent_session_lifetime),
|
||||
)
|
||||
|
||||
if self.use_signer:
|
||||
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
|
||||
else:
|
||||
session_id = session.sid
|
||||
response.set_cookie(
|
||||
app.config["SESSION_COOKIE_NAME"],
|
||||
session_id,
|
||||
expires=expires,
|
||||
httponly=httponly,
|
||||
domain=domain,
|
||||
path=path,
|
||||
secure=secure,
|
||||
**conditional_cookie_kwargs,
|
||||
)
|
108
api/extensions/ext_storage.py
Normal file
108
api/extensions/ext_storage.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import closing
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from flask import Flask
|
||||
|
||||
|
||||
class Storage:
|
||||
def __init__(self):
|
||||
self.storage_type = None
|
||||
self.bucket_name = None
|
||||
self.client = None
|
||||
self.folder = None
|
||||
|
||||
def init_app(self, app: Flask):
|
||||
self.storage_type = app.config.get('STORAGE_TYPE')
|
||||
if self.storage_type == 's3':
|
||||
self.bucket_name = app.config.get('S3_BUCKET_NAME')
|
||||
self.client = boto3.client(
|
||||
's3',
|
||||
aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
|
||||
aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
|
||||
endpoint_url=app.config.get('S3_ENDPOINT'),
|
||||
region_name=app.config.get('S3_REGION')
|
||||
)
|
||||
else:
|
||||
self.folder = app.config.get('STORAGE_LOCAL_PATH')
|
||||
if not os.path.isabs(self.folder):
|
||||
self.folder = os.path.join(app.root_path, self.folder)
|
||||
|
||||
def save(self, filename, data):
|
||||
if self.storage_type == 's3':
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
else:
|
||||
filename = self.folder + '/' + filename
|
||||
|
||||
folder = os.path.dirname(filename)
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
with open(os.path.join(os.getcwd(), filename), "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
def load(self, filename):
|
||||
if self.storage_type == 's3':
|
||||
try:
|
||||
with closing(self.client) as client:
|
||||
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
|
||||
except ClientError as ex:
|
||||
if ex.response['Error']['Code'] == 'NoSuchKey':
|
||||
raise FileNotFoundError("File not found")
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
else:
|
||||
filename = self.folder + '/' + filename
|
||||
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
|
||||
return data
|
||||
|
||||
def download(self, filename, target_filepath):
|
||||
if self.storage_type == 's3':
|
||||
with closing(self.client) as client:
|
||||
client.download_file(self.bucket_name, filename, target_filepath)
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
else:
|
||||
filename = self.folder + '/' + filename
|
||||
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
|
||||
shutil.copyfile(filename, target_filepath)
|
||||
|
||||
def exists(self, filename):
|
||||
if self.storage_type == 's3':
|
||||
with closing(self.client) as client:
|
||||
try:
|
||||
client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
else:
|
||||
if not self.folder or self.folder.endswith('/'):
|
||||
filename = self.folder + filename
|
||||
else:
|
||||
filename = self.folder + '/' + filename
|
||||
|
||||
return os.path.exists(filename)
|
||||
|
||||
|
||||
storage = Storage()
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
storage.init_app(app)
|
7
api/extensions/ext_vector_store.py
Normal file
7
api/extensions/ext_vector_store.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from core.vector_store.vector_store import VectorStore
|
||||
|
||||
vector_store = VectorStore()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
vector_store.init_app(app)
|
Reference in New Issue
Block a user