add more typing (#24949)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-08 11:40:00 +09:00
committed by GitHub
parent ce2281d31b
commit f6059ef389
9 changed files with 97 additions and 74 deletions

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, reqparse
@@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
@@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
def admin_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import cast
from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login
from flask import jsonify, request
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app:
raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return view(self, oauth_provider_app, *args, **kwargs)
return decorated
def oauth_server_access_token_required(view):
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer"
return response
kwargs["account"] = account
return view(*args, **kwargs)
return view(self, oauth_provider_app, account, *args, **kwargs)
return decorated

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
@@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def decorator(view):
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
installed_app = (
db.session.query(InstalledApp)
.where(
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator
def user_allowed_to_access_app(view=None):
def decorator(view):
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs):
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user
tenant_id = user.current_tenant_id

View File

@@ -2,7 +2,9 @@ import contextlib
import json
import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
@@ -19,10 +21,13 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
@@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated
def only_edition_cloud(view):
def only_edition_cloud(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
abort(404)
@@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated
def only_edition_enterprise(view):
def only_edition_enterprise(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated
def only_edition_self_hosted(view):
def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
abort(404)
@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated
def cloud_edition_billing_enabled(view):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
def cloud_edition_billing_resource_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
members = features.members
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor
def cloud_utm_record(view):
def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated
def setup_required(view):
def setup_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
@@ -212,9 +217,9 @@ def setup_required(view):
return decorated
def enterprise_license_required(view):
def enterprise_license_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated
def email_password_login_enabled(view):
def email_password_login_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated
def enable_change_email(view):
def enable_change_email(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_change_email:
return view(*args, **kwargs)
@@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated
def is_allow_transfer_owner(view):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)

View File

@@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
from typing import Optional, ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
@@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
class WhereisUserArg(StrEnum):
"""
@@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view):
def decorated(*args, **kwargs):
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
@@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
@@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge":

View File

@@ -1,5 +1,6 @@
from datetime import UTC, datetime
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view=None):
def decorator(view):

View File

@@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class MatrixoneConfig(BaseModel):

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from functools import wraps
from typing import Union, cast
@@ -12,9 +13,13 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func):
def login_required(func: Callable[P, R]):
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@@ -49,17 +54,12 @@ def login_required(func):
"""
@wraps(func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view