add more typing (#24949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import api
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
def oauth_server_client_id_required(view):
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
kwargs["oauth_provider_app"] = oauth_provider_app
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view):
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
oauth_provider_app = kwargs.get("oauth_provider_app")
|
||||
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
kwargs["account"] = account
|
||||
|
||||
return view(*args, **kwargs)
|
||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
Reference in New Issue
Block a user