add more typing (#24949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
@@ -6,6 +8,8 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@@ -14,9 +18,9 @@ from extensions.ext_database import db
|
|||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view):
|
def admin_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import cast
|
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
|||||||
|
|
||||||
from .. import api
|
from .. import api
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
def oauth_server_client_id_required(view):
|
|
||||||
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
|
|||||||
if not oauth_provider_app:
|
if not oauth_provider_app:
|
||||||
raise NotFound("client_id is invalid")
|
raise NotFound("client_id is invalid")
|
||||||
|
|
||||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
return view(self, oauth_provider_app, *args, **kwargs)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_access_token_required(view):
|
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
|
||||||
raise BadRequest("Invalid oauth_provider_app")
|
raise BadRequest("Invalid oauth_provider_app")
|
||||||
|
|
||||||
authorization_header = request.headers.get("Authorization")
|
authorization_header = request.headers.get("Authorization")
|
||||||
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
|
|||||||
response.headers["WWW-Authenticate"] = "Bearer"
|
response.headers["WWW-Authenticate"] = "Bearer"
|
||||||
return response
|
return response
|
||||||
|
|
||||||
kwargs["account"] = account
|
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -13,19 +15,15 @@ from services.app_service import AppService
|
|||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
def installed_app_required(view=None):
|
|
||||||
def decorator(view):
|
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||||
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
if not kwargs.get("installed_app_id"):
|
|
||||||
raise ValueError("missing installed_app_id in path parameters")
|
|
||||||
|
|
||||||
installed_app_id = kwargs.get("installed_app_id")
|
|
||||||
installed_app_id = str(installed_app_id)
|
|
||||||
|
|
||||||
del kwargs["installed_app_id"]
|
|
||||||
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(
|
||||||
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def user_allowed_to_access_app(view=None):
|
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||||
def decorator(view):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def plugin_permission_required(
|
def plugin_permission_required(
|
||||||
install_required: bool = False,
|
install_required: bool = False,
|
||||||
debug_required: bool = False,
|
debug_required: bool = False,
|
||||||
):
|
):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
@@ -2,7 +2,9 @@ import contextlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import abort, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@@ -19,10 +21,13 @@ from services.operation_service import OperationService
|
|||||||
|
|
||||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
def account_initialization_required(view):
|
|
||||||
|
def account_initialization_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
@@ -34,9 +39,9 @@ def account_initialization_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_cloud(view):
|
def only_edition_cloud(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@@ -45,9 +50,9 @@ def only_edition_cloud(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_enterprise(view):
|
def only_edition_enterprise(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if not dify_config.ENTERPRISE_ENABLED:
|
if not dify_config.ENTERPRISE_ENABLED:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def only_edition_self_hosted(view):
|
def only_edition_self_hosted(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if dify_config.EDITION != "SELF_HOSTED":
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_enabled(view):
|
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str):
|
def cloud_edition_billing_resource_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def cloud_utm_record(view):
|
def cloud_utm_record(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
@@ -194,9 +199,9 @@ def cloud_utm_record(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def setup_required(view):
|
def setup_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check setup
|
# check setup
|
||||||
if (
|
if (
|
||||||
dify_config.EDITION == "SELF_HOSTED"
|
dify_config.EDITION == "SELF_HOSTED"
|
||||||
@@ -212,9 +217,9 @@ def setup_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enterprise_license_required(view):
|
def enterprise_license_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
settings = FeatureService.get_system_features()
|
settings = FeatureService.get_system_features()
|
||||||
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
|
||||||
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
|
||||||
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_password_login_enabled(view):
|
def email_password_login_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_email_password_login:
|
if features.enable_email_password_login:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def enable_change_email(view):
|
def enable_change_email(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.enable_change_email:
|
if features.enable_change_email:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@@ -250,9 +255,9 @@ def enable_change_email(view):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def is_allow_transfer_owner(view):
|
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@@ -3,7 +3,7 @@ from collections.abc import Callable
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
|
|||||||
from models.model import ApiToken, App, EndUser
|
from models.model import ApiToken, App, EndUser
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class WhereisUserArg(StrEnum):
|
class WhereisUserArg(StrEnum):
|
||||||
"""
|
"""
|
||||||
@@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
features = FeatureService.get_features(api_token.tenant_id)
|
features = FeatureService.get_features(api_token.tenant_id)
|
||||||
|
|
||||||
@@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
features = FeatureService.get_features(api_token.tenant_id)
|
features = FeatureService.get_features(api_token.tenant_id)
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
@@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
|||||||
|
|
||||||
|
|
||||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||||
def interceptor(view):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
api_token = validate_and_get_api_token(api_token_type)
|
api_token = validate_and_get_api_token(api_token_type)
|
||||||
|
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
|||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def validate_jwt_token(view=None):
|
def validate_jwt_token(view=None):
|
||||||
def decorator(view):
|
def decorator(view):
|
||||||
|
@@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class MatrixoneConfig(BaseModel):
|
class MatrixoneConfig(BaseModel):
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Union, cast
|
from typing import Union, cast
|
||||||
|
|
||||||
@@ -12,9 +13,13 @@ from models.model import EndUser
|
|||||||
#: A proxy for the current user. If no user is logged in, this will be an
|
#: A proxy for the current user. If no user is logged in, this will be an
|
||||||
#: anonymous user
|
#: anonymous user
|
||||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def login_required(func):
|
def login_required(func: Callable[P, R]):
|
||||||
"""
|
"""
|
||||||
If you decorate a view with this, it will ensure that the current user is
|
If you decorate a view with this, it will ensure that the current user is
|
||||||
logged in and authenticated before calling the actual view. (If they are
|
logged in and authenticated before calling the actual view. (If they are
|
||||||
@@ -49,17 +54,12 @@ def login_required(func):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def decorated_view(*args, **kwargs):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||||
pass
|
pass
|
||||||
elif current_user is not None and not current_user.is_authenticated:
|
elif current_user is not None and not current_user.is_authenticated:
|
||||||
return current_app.login_manager.unauthorized() # type: ignore
|
return current_app.login_manager.unauthorized() # type: ignore
|
||||||
|
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||||
# flask 1.x compatibility
|
|
||||||
# current_app.ensure_sync is only available in Flask >= 2.0
|
|
||||||
if callable(getattr(current_app, "ensure_sync", None)):
|
|
||||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return decorated_view
|
return decorated_view
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user