add more typing (#24949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
@@ -6,6 +8,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
@@ -14,9 +18,9 @@ from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
def admin_required(view):
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import api
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
@@ -13,19 +15,15 @@ from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def installed_app_required(view=None):
|
||||
def decorator(view):
|
||||
|
||||
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not kwargs.get("installed_app_id"):
|
||||
raise ValueError("missing installed_app_id in path parameters")
|
||||
|
||||
installed_app_id = kwargs.get("installed_app_id")
|
||||
installed_app_id = str(installed_app_id)
|
||||
|
||||
del kwargs["installed_app_id"]
|
||||
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view=None):
|
||||
def decorator(view):
|
||||
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
|
@@ -2,7 +2,9 @@ import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
@@ -19,10 +21,13 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def account_initialization_required(view):
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
|
||||
@@ -34,9 +39,9 @@ def account_initialization_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud(view):
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
abort(404)
|
||||
|
||||
@@ -45,9 +50,9 @@ def only_edition_cloud(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view):
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view):
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
abort(404)
|
||||
|
||||
@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled(view):
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view):
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
@@ -194,9 +199,9 @@ def cloud_utm_record(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
def setup_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
@@ -212,9 +217,9 @@ def setup_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required(view):
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled(view):
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_email_password_login:
|
||||
return view(*args, **kwargs)
|
||||
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view):
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if features.enable_change_email:
|
||||
return view(*args, **kwargs)
|
||||
@@ -250,9 +255,9 @@ def enable_change_email(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner(view):
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
@@ -3,7 +3,7 @@ from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
from typing import Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
@@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
|
||||
@@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
@@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
|
||||
if resource == "knowledge":
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def validate_jwt_token(view=None):
|
||||
def decorator(view):
|
||||
|
@@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseModel):
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Union, cast
|
||||
|
||||
@@ -12,9 +13,13 @@ from models.model import EndUser
|
||||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def login_required(func):
|
||||
def login_required(func: Callable[P, R]):
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
@@ -49,17 +54,12 @@ def login_required(func):
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
|
||||
# flask 1.x compatibility
|
||||
# current_app.ensure_sync is only available in Flask >= 2.0
|
||||
if callable(getattr(current_app, "ensure_sync", None)):
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
Reference in New Issue
Block a user