diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8dcffb166..e840c0028 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource): Get draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource): Sync draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): Run draft workflow """ # The role of the current user in the ta table must be admin, owner, or editor + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): """ Run draft workflow iteration node """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() @@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource): Run draft workflow iteration node """ # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() - if not isinstance(current_user, Account): raise Forbidden() + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): """ Run draft workflow loop node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource): """ Run draft workflow loop node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, location="json") @@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource): """ Run draft workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") @@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource): """ Stop workflow task """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource): """ Run draft workflow node """ - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: - raise Forbidden() if not isinstance(current_user, Account): raise Forbidden() + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") @@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource): """ Get published workflow """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource): """ Publish workflow """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, default="", location="json") parser.add_argument("marked_comment", type=str, required=False, default="", location="json") @@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource): """ Get default block config """ + + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("q", type=str, location="args") args = parser.parse_args() @@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource): Convert expert mode of chatbot app to workflow mode Convert Completion App to Workflow App """ + if not isinstance(current_user, Account): + raise Forbidden() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - if request.data: parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=False, nullable=True, location="json") @@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource): """ Get published workflows """ + + if not isinstance(current_user, Account): + raise Forbidden() if not current_user.is_editor: raise Forbidden() @@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource): """ Update workflow attributes """ + if not isinstance(current_user, Account): + raise Forbidden() # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - parser = reqparse.RequestParser() parser.add_argument("marked_name", type=str, required=False, location="json") parser.add_argument("marked_comment", type=str, required=False, location="json") @@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource): """ Delete workflow """ + if not isinstance(current_user, Account): + raise Forbidden() # Check permission if not current_user.is_editor: raise Forbidden() - if not isinstance(current_user, Account): - raise Forbidden() - workflow_service = WorkflowService() # Create a session and manage the transaction diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 4e625db24..a0b73f7e0 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required from models import App, AppMode, db +from models.account import Account from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -135,6 +136,7 @@ def _api_prerequisite(f): @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) def wrapper(*args, **kwargs): + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 132dc1f96..c7e300279 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user from models import App, AppMode +from models.account import Account def _load_app_model(app_id: str) -> Optional[App]: + assert isinstance(current_user, Account) app_model = ( db.session.query(App) .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3d872fc1f..c1848ceed 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - + assert current_user is not None try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True @@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() + assert current_user is not None AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a54511bf..7c1bc7c07 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_user, login_required -from models.account import TenantAccountRole +from models.account import Account, TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService @@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + assert isinstance(current_user, Account) if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id + assert tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): + assert isinstance(current_user, Account) if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id + assert tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 6bc94af8c..9038bda11 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user +from models.account import Account from models.model import App from services.annotation_service import AppAnnotationService @@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id): """Update an existing annotation.""" + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() @@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource): @validate_app_token def delete(self, app_model: App, annotation_id): """Delete an annotation.""" + assert isinstance(current_user, Account) + if not current_user.is_editor: raise Forbidden() diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c486b0480..7b74c961b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user +from models.account import Account from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel @@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource): ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource): ) try: + assert isinstance(current_user, Account) dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, name=args["name"], @@ -319,7 +324,10 @@ class DatasetApi(DatasetApiResource): # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -391,6 +399,7 @@ class DatasetApi(DatasetApiResource): raise NotFound("Dataset not found.") result_data = marshal(dataset, dataset_detail_fields) + assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id if data.get("partial_member_list") and data.get("permission") == "partial_members": @@ -532,7 +541,10 @@ class DatasetTagsApi(DatasetApiResource): @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _, dataset_id): """Get all knowledge type tags.""" - tags = TagService.get_tags("knowledge", current_user.current_tenant_id) + assert isinstance(current_user, Account) + cid = current_user.current_tenant_id + assert cid is not None + tags = TagService.get_tags("knowledge", cid) return tags, 200 @@ -550,6 +562,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): """Add a knowledge type tag.""" + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -573,6 +586,7 @@ class DatasetTagsApi(DatasetApiResource): @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def patch(self, _, dataset_id): + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -599,6 +613,7 @@ class DatasetTagsApi(DatasetApiResource): @validate_dataset_token def delete(self, _, dataset_id): """Delete a knowledge type tag.""" + assert isinstance(current_user, Account) if not current_user.is_editor: raise Forbidden() args = tag_delete_parser.parse_args() @@ -622,6 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -647,6 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + assert isinstance(current_user, Account) if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() @@ -672,6 +689,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" dataset_id = kwargs.get("dataset_id") + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] response = {"data": tags_list, "total": len(tags)} diff --git a/api/libs/login.py b/api/libs/login.py index e3a7fe294..711d16e3b 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Any +from typing import Union, cast from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS # type: ignore @@ -11,7 +11,7 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user: Any = LocalProxy(lambda: _get_user()) +current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) def login_required(func): @@ -52,7 +52,7 @@ def login_required(func): def decorated_view(*args, **kwargs): if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass - elif not current_user.is_authenticated: + elif current_user is not None and not current_user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility