From f6059ef38991abc87acf2739fa8492bd1779fc6a Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 8 Sep 2025 11:40:00 +0900 Subject: [PATCH] add more typing (#24949) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/admin.py | 8 ++- api/controllers/console/auth/oauth_server.py | 26 ++++---- api/controllers/console/explore/wraps.py | 26 ++++---- api/controllers/console/workspace/__init__.py | 9 ++- api/controllers/console/wraps.py | 61 ++++++++++--------- api/controllers/service_api/wraps.py | 17 +++--- api/controllers/web/wraps.py | 4 ++ .../vdb/matrixone/matrixone_vector.py | 4 ++ api/libs/login.py | 16 ++--- 9 files changed, 97 insertions(+), 74 deletions(-) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index cae2d7cbe..1306efacf 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from functools import wraps +from typing import ParamSpec, TypeVar from flask import request from flask_restx import Resource, reqparse @@ -6,6 +8,8 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized +P = ParamSpec("P") +R = TypeVar("R") from configs import dify_config from constants.languages import supported_language from controllers.console import api @@ -14,9 +18,9 @@ from extensions.ext_database import db from models.model import App, InstalledApp, RecommendedApp -def admin_required(view): +def admin_required(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index a8ba41784..a54c1443f 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import cast +from typing import Concatenate, ParamSpec, TypeVar, cast import flask_login from flask import jsonify, request @@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, from .. import api +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") -def oauth_server_client_id_required(view): + +def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(self: T, *args: P.args, **kwargs: P.kwargs): parser = reqparse.RequestParser() parser.add_argument("client_id", type=str, required=True, location="json") parsed_args = parser.parse_args() @@ -30,18 +35,15 @@ def oauth_server_client_id_required(view): if not oauth_provider_app: raise NotFound("client_id is invalid") - kwargs["oauth_provider_app"] = oauth_provider_app - - return view(*args, **kwargs) + return view(self, oauth_provider_app, *args, **kwargs) return decorated -def oauth_server_access_token_required(view): +def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): @wraps(view) - def decorated(*args, **kwargs): - oauth_provider_app = kwargs.get("oauth_provider_app") - if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): + def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): + if not isinstance(oauth_provider_app, OAuthProviderApp): raise BadRequest("Invalid oauth_provider_app") authorization_header = request.headers.get("Authorization") @@ -79,9 +81,7 @@ def oauth_server_access_token_required(view): response.headers["WWW-Authenticate"] = "Bearer" return response - kwargs["account"] = account - - return view(*args, **kwargs) + return view(self, oauth_provider_app, account, *args, **kwargs) return decorated diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index e86103184..6401f804c 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from functools import wraps +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask_login import current_user from flask_restx import Resource @@ -13,19 +15,15 @@ from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") -def installed_app_required(view=None): - def decorator(view): + +def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): + def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) - def decorated(*args, **kwargs): - if not kwargs.get("installed_app_id"): - raise ValueError("missing installed_app_id in path parameters") - - installed_app_id = kwargs.get("installed_app_id") - installed_app_id = str(installed_app_id) - - del kwargs["installed_app_id"] - + def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): installed_app = ( db.session.query(InstalledApp) .where( @@ -52,10 +50,10 @@ def installed_app_required(view=None): return decorator -def user_allowed_to_access_app(view=None): - def decorator(view): +def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): + def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) - def decorated(installed_app: InstalledApp, *args, **kwargs): + def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): feature = FeatureService.get_system_features() if feature.webapp_auth.enabled: app_id = installed_app.app_id diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index ef814dd73..4a048f3c5 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from functools import wraps +from typing import ParamSpec, TypeVar from flask_login import current_user from sqlalchemy.orm import Session @@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden from extensions.ext_database import db from models.account import TenantPluginPermission +P = ParamSpec("P") +R = TypeVar("R") + def plugin_permission_required( install_required: bool = False, debug_required: bool = False, ): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): user = current_user tenant_id = user.current_tenant_id diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d3fd1d52e..e375fe285 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -2,7 +2,9 @@ import contextlib import json import os import time +from collections.abc import Callable from functools import wraps +from typing import ParamSpec, TypeVar from flask import abort, request from flask_login import current_user @@ -19,10 +21,13 @@ from services.operation_service import OperationService from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout +P = ParamSpec("P") +R = TypeVar("R") -def account_initialization_required(view): + +def account_initialization_required(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): # check account initialization account = current_user @@ -34,9 +39,9 @@ def account_initialization_required(view): return decorated -def only_edition_cloud(view): +def only_edition_cloud(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if dify_config.EDITION != "CLOUD": abort(404) @@ -45,9 +50,9 @@ def only_edition_cloud(view): return decorated -def only_edition_enterprise(view): +def only_edition_enterprise(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.ENTERPRISE_ENABLED: abort(404) @@ -56,9 +61,9 @@ def only_edition_enterprise(view): return decorated -def only_edition_self_hosted(view): +def only_edition_self_hosted(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if dify_config.EDITION != "SELF_HOSTED": abort(404) @@ -67,9 +72,9 @@ def only_edition_self_hosted(view): return decorated -def cloud_edition_billing_enabled(view): +def cloud_edition_billing_enabled(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if not features.billing.enabled: abort(403, "Billing feature is not enabled.") @@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view): def cloud_edition_billing_resource_check(resource: str): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: members = features.members @@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_knowledge_limit_check(resource: str): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: if resource == "add_segment": @@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_rate_limit_check(resource: str): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if resource == "knowledge": knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) if knowledge_rate_limit.enabled: @@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str): return interceptor -def cloud_utm_record(view): +def cloud_utm_record(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): with contextlib.suppress(Exception): features = FeatureService.get_features(current_user.current_tenant_id) @@ -194,9 +199,9 @@ def cloud_utm_record(view): return decorated -def setup_required(view): +def setup_required(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): # check setup if ( dify_config.EDITION == "SELF_HOSTED" @@ -212,9 +217,9 @@ def setup_required(view): return decorated -def enterprise_license_required(view): +def enterprise_license_required(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): settings = FeatureService.get_system_features() if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") @@ -224,9 +229,9 @@ def enterprise_license_required(view): return decorated -def email_password_login_enabled(view): +def email_password_login_enabled(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_system_features() if features.enable_email_password_login: return view(*args, **kwargs) @@ -237,9 +242,9 @@ def email_password_login_enabled(view): return decorated -def enable_change_email(view): +def enable_change_email(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_system_features() if features.enable_change_email: return view(*args, **kwargs) @@ -250,9 +255,9 @@ def enable_change_email(view): return decorated -def is_allow_transfer_owner(view): +def is_allow_transfer_owner(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.is_allow_transfer_workspace: return view(*args, **kwargs) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 67d48319d..4d71e5839 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional +from typing import Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService +P = ParamSpec("P") +R = TypeVar("R") + class WhereisUserArg(StrEnum): """ @@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def cloud_edition_billing_resource_check(resource: str, api_token_type: str): - def interceptor(view): - def decorated(*args, **kwargs): + def interceptor(view: Callable[P, R]): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) @@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: @@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): - def interceptor(view): + def interceptor(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token(api_token_type) if resource == "knowledge": diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fc8916ca..1fbb2c165 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,5 +1,6 @@ from datetime import UTC, datetime from functools import wraps +from typing import ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService +P = ParamSpec("P") +R = TypeVar("R") + def validate_jwt_token(view=None): def decorator(view): diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 9660cf8ab..7da830f64 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) +from typing import ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") class MatrixoneConfig(BaseModel): diff --git a/api/libs/login.py b/api/libs/login.py index 711d16e3b..0535f52ea 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from functools import wraps from typing import Union, cast @@ -12,9 +13,13 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) +from typing import ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") -def login_required(func): +def login_required(func: Callable[P, R]): """ If you decorate a view with this, it will ensure that the current user is logged in and authenticated before calling the actual view. (If they are @@ -49,17 +54,12 @@ def login_required(func): """ @wraps(func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif current_user is not None and not current_user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore - - # flask 1.x compatibility - # current_app.ensure_sync is only available in Flask >= 2.0 - if callable(getattr(current_app, "ensure_sync", None)): - return current_app.ensure_sync(func)(*args, **kwargs) - return func(*args, **kwargs) + return current_app.ensure_sync(func)(*args, **kwargs) return decorated_view