add more typing (#24949)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-08 11:40:00 +09:00
committed by GitHub
parent ce2281d31b
commit f6059ef389
9 changed files with 97 additions and 74 deletions

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
@@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
@@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
def admin_required(view): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import cast from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login import flask_login
from flask import jsonify, request from flask import jsonify, request
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json") parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app return view(self, oauth_provider_app, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated
def oauth_server_access_token_required(view): def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app") if not isinstance(oauth_provider_app, OAuthProviderApp):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization") authorization_header = request.headers.get("Authorization")
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer" response.headers["WWW-Authenticate"] = "Bearer"
return response return response
kwargs["account"] = account return view(self, oauth_provider_app, account, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
@@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def decorator(view): def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator return decorator
def user_allowed_to_access_app(view=None): def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
app_id = installed_app.app_id app_id = installed_app.app_id

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required( def plugin_permission_required(
install_required: bool = False, install_required: bool = False,
debug_required: bool = False, debug_required: bool = False,
): ):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id

View File

@@ -2,7 +2,9 @@ import contextlib
import json import json
import os import os
import time import time
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request from flask import abort, request
from flask_login import current_user from flask_login import current_user
@@ -19,10 +21,13 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
account = current_user account = current_user
@@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated return decorated
def only_edition_cloud(view): def only_edition_cloud(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
abort(404) abort(404)
@@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated return decorated
def only_edition_enterprise(view): def only_edition_enterprise(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED: if not dify_config.ENTERPRISE_ENABLED:
abort(404) abort(404)
@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated return decorated
def only_edition_self_hosted(view): def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
abort(404) abort(404)
@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated return decorated
def cloud_edition_billing_enabled(view): def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str): def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge": if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor return interceptor
def cloud_utm_record(view): def cloud_utm_record(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
@@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated return decorated
def setup_required(view): def setup_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup # check setup
if ( if (
dify_config.EDITION == "SELF_HOSTED" dify_config.EDITION == "SELF_HOSTED"
@@ -212,9 +217,9 @@ def setup_required(view):
return decorated return decorated
def enterprise_license_required(view): def enterprise_license_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features() settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated return decorated
def email_password_login_enabled(view): def email_password_login_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_email_password_login: if features.enable_email_password_login:
return view(*args, **kwargs) return view(*args, **kwargs)
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated return decorated
def enable_change_email(view): def enable_change_email(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_change_email: if features.enable_change_email:
return view(*args, **kwargs) return view(*args, **kwargs)
@@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated return decorated
def is_allow_transfer_owner(view): def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional, ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
class WhereisUserArg(StrEnum): class WhereisUserArg(StrEnum):
""" """
@@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
@@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled: if features.billing.enabled:
@@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge": if resource == "knowledge":

View File

@@ -1,5 +1,6 @@
from datetime import UTC, datetime from datetime import UTC, datetime
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view=None): def validate_jwt_token(view=None):
def decorator(view): def decorator(view):

View File

@@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class MatrixoneConfig(BaseModel): class MatrixoneConfig(BaseModel):

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Union, cast from typing import Union, cast
@@ -12,9 +13,13 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an #: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user #: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func): def login_required(func: Callable[P, R]):
""" """
If you decorate a view with this, it will ensure that the current user is If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are logged in and authenticated before calling the actual view. (If they are
@@ -49,17 +54,12 @@ def login_required(func):
""" """
@wraps(func) @wraps(func)
def decorated_view(*args, **kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass pass
elif current_user is not None and not current_user.is_authenticated: elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore return current_app.login_manager.unauthorized() # type: ignore
return current_app.ensure_sync(func)(*args, **kwargs)
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return decorated_view return decorated_view