example: limit current user usage (#24470)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-08-26 01:23:29 +09:00
committed by GitHub
parent b4be132201
commit 2b91ba2411
8 changed files with 81 additions and 41 deletions

View File

@@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource):
Get draft workflow Get draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource):
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
@@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
@@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
@@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument("inputs", type=dict, location="json")
@@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource):
""" """
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
@@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource):
""" """
Stop workflow task Stop workflow task
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource):
""" """
Run draft workflow node Run draft workflow node
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
@@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource):
""" """
Get published workflow Get published workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource):
""" """
Publish workflow Publish workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json") parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json") parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
@@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args") parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
@@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource):
Convert expert mode of chatbot app to workflow mode Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App Convert Completion App to Workflow App
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json") parser.add_argument("name", type=str, required=False, nullable=True, location="json")
@@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource):
""" """
Get published workflows Get published workflows
""" """
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource):
""" """
Update workflow attributes Update workflow attributes
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission # Check permission
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json") parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json") parser.add_argument("marked_comment", type=str, required=False, location="json")
@@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource):
""" """
Delete workflow Delete workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission # Check permission
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction

View File

@@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App, AppMode, db from models import App, AppMode, db
from models.account import Account
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@@ -135,6 +136,7 @@ def _api_prerequisite(f):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_user
from models import App, AppMode from models import App, AppMode
from models.account import Account
def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account)
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
assert current_user is not None
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
assert current_user is not None
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import TenantAccountRole from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
@@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
assert isinstance(current_user, Account)
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str, config_id: str): def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account)
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument("model", type=str, required=True, nullable=False, location="json")

View File

@@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user from libs.login import current_user
from models.account import Account
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns)) @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id): def put(self, app_model: App, annotation_id):
"""Update an existing annotation.""" """Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource):
@validate_app_token @validate_app_token
def delete(self, app_model: App, annotation_id): def delete(self, app_model: App, annotation_id):
"""Delete an annotation.""" """Delete an annotation."""
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()

View File

@@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user from libs.login import current_user
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
@@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource):
) )
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource):
) )
try: try:
assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
name=args["name"], name=args["name"],
@@ -319,7 +324,10 @@ class DatasetApi(DatasetApiResource):
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -391,6 +399,7 @@ class DatasetApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = marshal(dataset, dataset_detail_fields)
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
@@ -532,7 +541,10 @@ class DatasetTagsApi(DatasetApiResource):
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _, dataset_id): def get(self, _, dataset_id):
"""Get all knowledge type tags.""" """Get all knowledge type tags."""
tags = TagService.get_tags("knowledge", current_user.current_tenant_id) assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
return tags, 200 return tags, 200
@@ -550,6 +562,7 @@ class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
def post(self, _, dataset_id): def post(self, _, dataset_id):
"""Add a knowledge type tag.""" """Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
@@ -573,6 +586,7 @@ class DatasetTagsApi(DatasetApiResource):
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
@validate_dataset_token @validate_dataset_token
def patch(self, _, dataset_id): def patch(self, _, dataset_id):
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
@@ -599,6 +613,7 @@ class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
@@ -622,6 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
def post(self, _, dataset_id): def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
@@ -647,6 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
def post(self, _, dataset_id): def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor): if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
@@ -672,6 +689,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
def get(self, _, *args, **kwargs): def get(self, _, *args, **kwargs):
"""Get all knowledge type tags.""" """Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id") dataset_id = kwargs.get("dataset_id")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)} response = {"data": tags_list, "total": len(tags)}

View File

@@ -1,5 +1,5 @@
from functools import wraps from functools import wraps
from typing import Any from typing import Union, cast
from flask import current_app, g, has_request_context, request from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore from flask_login.config import EXEMPT_METHODS # type: ignore
@@ -11,7 +11,7 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an #: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user #: anonymous user
current_user: Any = LocalProxy(lambda: _get_user()) current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
def login_required(func): def login_required(func):
@@ -52,7 +52,7 @@ def login_required(func):
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass pass
elif not current_user.is_authenticated: elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore return current_app.login_manager.unauthorized() # type: ignore
# flask 1.x compatibility # flask 1.x compatibility