diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index cdec92aee..0b3e5eb42 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -17,6 +17,7 @@ from core.workflow.entities.workflow_execution import ( ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, CreatorUserRole, @@ -67,7 +68,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 797cce935..a5feeb0d7 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, CreatorUserRole, @@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 23cf4c5ca..b62b0b60d 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config from dify_app import DifyApp +from libs.helper import extract_tenant_id from models import Account, EndUser @@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]): if user: try: current_span = get_current_span() - if isinstance(user, Account) and user.current_tenant_id: - tenant_id = user.current_tenant_id - elif isinstance(user, EndUser): - tenant_id = user.tenant_id - else: + tenant_id = extract_tenant_id(user) + if not tenant_id: return if current_span: current_span.set_attribute("service.tenant.id", tenant_id) diff --git a/api/libs/helper.py b/api/libs/helper.py index 3f2a63095..48126461a 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client if TYPE_CHECKING: from models.account import Account + from models.model import EndUser + + +def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: + """ + Extract tenant_id from Account or EndUser object. + + Args: + user: Account or EndUser object + + Returns: + tenant_id string if available, None otherwise + + Raises: + ValueError: If user is neither Account nor EndUser + """ + from models.account import Account + from models.model import EndUser + + if isinstance(user, Account): + return user.current_tenant_id + elif isinstance(user, EndUser): + return user.tenant_id + else: + raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.") def run(script): diff --git a/api/models/workflow.py b/api/models/workflow.py index 7f01135af..77d48bec4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -15,6 +15,7 @@ from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -352,12 +353,7 @@ class Workflow(Base): self._environment_variables = "{}" # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: return [] @@ -384,12 +380,7 @@ class Workflow(Base): return # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: self._environment_variables = "{}" diff --git a/api/services/file_service.py b/api/services/file_service.py index 2d68f30c5..286535bd1 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -18,6 +18,7 @@ from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage +from libs.helper import extract_tenant_id from models.account import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -61,11 +62,7 @@ class FileService: # generate file key file_uuid = str(uuid.uuid4()) - if isinstance(user, Account): - current_tenant_id = user.current_tenant_id - else: - # end_user - current_tenant_id = user.tenant_id + current_tenant_id = extract_tenant_id(user) file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py new file mode 100644 index 000000000..b7701055f --- /dev/null +++ b/api/tests/unit_tests/libs/test_helper.py @@ -0,0 +1,65 @@ +import pytest + +from libs.helper import extract_tenant_id +from models.account import Account +from models.model import EndUser + + +class TestExtractTenantId: + """Test cases for the extract_tenant_id utility function.""" + + def test_extract_tenant_id_from_account_with_tenant(self): + """Test extracting tenant_id from Account with current_tenant_id.""" + # Create a mock Account object + account = Account() + # Mock the current_tenant_id property + account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() + + tenant_id = extract_tenant_id(account) + assert tenant_id == "account-tenant-123" + + def test_extract_tenant_id_from_account_without_tenant(self): + """Test extracting tenant_id from Account without current_tenant_id.""" + # Create a mock Account object + account = Account() + account._current_tenant = None + + tenant_id = extract_tenant_id(account) + assert tenant_id is None + + def test_extract_tenant_id_from_enduser_with_tenant(self): + """Test extracting tenant_id from EndUser with tenant_id.""" + # Create a mock EndUser object + end_user = EndUser() + end_user.tenant_id = "enduser-tenant-456" + + tenant_id = extract_tenant_id(end_user) + assert tenant_id == "enduser-tenant-456" + + def test_extract_tenant_id_from_enduser_without_tenant(self): + """Test extracting tenant_id from EndUser without tenant_id.""" + # Create a mock EndUser object + end_user = EndUser() + end_user.tenant_id = None + + tenant_id = extract_tenant_id(end_user) + assert tenant_id is None + + def test_extract_tenant_id_with_invalid_user_type(self): + """Test extracting tenant_id with invalid user type raises ValueError.""" + invalid_user = "not_a_user_object" + + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(invalid_user) + + def test_extract_tenant_id_with_none_user(self): + """Test extracting tenant_id with None user raises ValueError.""" + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(None) + + def test_extract_tenant_id_with_dict_user(self): + """Test extracting tenant_id with dict user raises ValueError.""" + dict_user = {"id": "123", "tenant_id": "456"} + + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(dict_user) diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 69163d48b..5bc77ad0e 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -9,6 +9,7 @@ from core.file.models import File from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment +from models.model import EndUser from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable @@ -43,7 +44,7 @@ def test_environment_variables(): ) # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" with ( @@ -90,7 +91,7 @@ def test_update_environment_variables(): ) # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" with ( @@ -136,7 +137,7 @@ def test_to_dict(): # Create some EnvironmentVariable instances # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" with (