feat(libs): Introduce extract_tenant_id
(#22086)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from core.workflow.entities.workflow_execution import (
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
@@ -67,7 +68,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
@@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import (
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
@@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
@@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, EndUser
|
||||
|
||||
|
||||
@@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]):
|
||||
if user:
|
||||
try:
|
||||
current_span = get_current_span()
|
||||
if isinstance(user, Account) and user.current_tenant_id:
|
||||
tenant_id = user.current_tenant_id
|
||||
elif isinstance(user, EndUser):
|
||||
tenant_id = user.tenant_id
|
||||
else:
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
return
|
||||
if current_span:
|
||||
current_span.set_attribute("service.tenant.id", tenant_id)
|
||||
|
@@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
||||
"""
|
||||
Extract tenant_id from Account or EndUser object.
|
||||
|
||||
Args:
|
||||
user: Account or EndUser object
|
||||
|
||||
Returns:
|
||||
tenant_id string if available, None otherwise
|
||||
|
||||
Raises:
|
||||
ValueError: If user is neither Account nor EndUser
|
||||
"""
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
if isinstance(user, Account):
|
||||
return user.current_tenant_id
|
||||
elif isinstance(user, EndUser):
|
||||
return user.tenant_id
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
|
||||
|
||||
|
||||
def run(script):
|
||||
|
@@ -15,6 +15,7 @@ from core.variables import utils as variable_utils
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.helper import extract_tenant_id
|
||||
|
||||
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
||||
|
||||
@@ -352,12 +353,7 @@ class Workflow(Base):
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Get tenant_id from current_user (Account or EndUser)
|
||||
if isinstance(current_user, Account):
|
||||
# Account user
|
||||
tenant_id = current_user.current_tenant_id
|
||||
else:
|
||||
# EndUser
|
||||
tenant_id = current_user.tenant_id
|
||||
tenant_id = extract_tenant_id(current_user)
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
@@ -384,12 +380,7 @@ class Workflow(Base):
|
||||
return
|
||||
|
||||
# Get tenant_id from current_user (Account or EndUser)
|
||||
if isinstance(current_user, Account):
|
||||
# Account user
|
||||
tenant_id = current_user.current_tenant_id
|
||||
else:
|
||||
# EndUser
|
||||
tenant_id = current_user.tenant_id
|
||||
tenant_id = extract_tenant_id(current_user)
|
||||
|
||||
if not tenant_id:
|
||||
self._environment_variables = "{}"
|
||||
|
@@ -18,6 +18,7 @@ from core.file import helpers as file_helpers
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@@ -61,11 +62,7 @@ class FileService:
|
||||
# generate file key
|
||||
file_uuid = str(uuid.uuid4())
|
||||
|
||||
if isinstance(user, Account):
|
||||
current_tenant_id = user.current_tenant_id
|
||||
else:
|
||||
# end_user
|
||||
current_tenant_id = user.tenant_id
|
||||
current_tenant_id = extract_tenant_id(user)
|
||||
|
||||
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
|
||||
|
||||
|
65
api/tests/unit_tests/libs/test_helper.py
Normal file
65
api/tests/unit_tests/libs/test_helper.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class TestExtractTenantId:
|
||||
"""Test cases for the extract_tenant_id utility function."""
|
||||
|
||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id == "account-tenant-123"
|
||||
|
||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_from_enduser_with_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser with tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = "enduser-tenant-456"
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id == "enduser-tenant-456"
|
||||
|
||||
def test_extract_tenant_id_from_enduser_without_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser without tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = None
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_with_invalid_user_type(self):
|
||||
"""Test extracting tenant_id with invalid user type raises ValueError."""
|
||||
invalid_user = "not_a_user_object"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(invalid_user)
|
||||
|
||||
def test_extract_tenant_id_with_none_user(self):
|
||||
"""Test extracting tenant_id with None user raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(None)
|
||||
|
||||
def test_extract_tenant_id_with_dict_user(self):
|
||||
"""Test extracting tenant_id with dict user raises ValueError."""
|
||||
dict_user = {"id": "123", "tenant_id": "456"}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
@@ -9,6 +9,7 @@ from core.file.models import File
|
||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from core.variables.segments import IntegerSegment, Segment
|
||||
from factories.variable_factory import build_segment
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ def test_environment_variables():
|
||||
)
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
@@ -90,7 +91,7 @@ def test_update_environment_variables():
|
||||
)
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
@@ -136,7 +137,7 @@ def test_to_dict():
|
||||
# Create some EnvironmentVariable instances
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
|
Reference in New Issue
Block a user