feat(libs): Introduce extract_tenant_id (#22086)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-07-09 17:45:56 +08:00
committed by GitHub
parent 1885426421
commit 4cb50f1809
8 changed files with 106 additions and 27 deletions

View File

@@ -17,6 +17,7 @@ from core.workflow.entities.workflow_execution import (
) )
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@@ -67,7 +68,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
) )
# Extract tenant_id from user # Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id tenant_id = extract_tenant_id(user)
if not tenant_id: if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id") raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id self._tenant_id = tenant_id

View File

@@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import (
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
) )
# Extract tenant_id from user # Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id tenant_id = extract_tenant_id(user)
if not tenant_id: if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id") raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id self._tenant_id = tenant_id

View File

@@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore
from configs import dify_config from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
from libs.helper import extract_tenant_id
from models import Account, EndUser from models import Account, EndUser
@@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]):
if user: if user:
try: try:
current_span = get_current_span() current_span = get_current_span()
if isinstance(user, Account) and user.current_tenant_id: tenant_id = extract_tenant_id(user)
tenant_id = user.current_tenant_id if not tenant_id:
elif isinstance(user, EndUser):
tenant_id = user.tenant_id
else:
return return
if current_span: if current_span:
current_span.set_attribute("service.tenant.id", tenant_id) current_span.set_attribute("service.tenant.id", tenant_id)

View File

@@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client
if TYPE_CHECKING: if TYPE_CHECKING:
from models.account import Account from models.account import Account
from models.model import EndUser
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
Args:
user: Account or EndUser object
Returns:
tenant_id string if available, None otherwise
Raises:
ValueError: If user is neither Account nor EndUser
"""
from models.account import Account
from models.model import EndUser
if isinstance(user, Account):
return user.current_tenant_id
elif isinstance(user, EndUser):
return user.tenant_id
else:
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
def run(script): def run(script):

View File

@@ -15,6 +15,7 @@ from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.helper import extract_tenant_id
from ._workflow_exc import NodeNotFoundError, WorkflowDataError from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@@ -352,12 +353,7 @@ class Workflow(Base):
self._environment_variables = "{}" self._environment_variables = "{}"
# Get tenant_id from current_user (Account or EndUser) # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account): tenant_id = extract_tenant_id(current_user)
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id: if not tenant_id:
return [] return []
@@ -384,12 +380,7 @@ class Workflow(Base):
return return
# Get tenant_id from current_user (Account or EndUser) # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account): tenant_id = extract_tenant_id(current_user)
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id: if not tenant_id:
self._environment_variables = "{}" self._environment_variables = "{}"

View File

@@ -18,6 +18,7 @@ from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs.helper import extract_tenant_id
from models.account import Account from models.account import Account
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile from models.model import EndUser, UploadFile
@@ -61,11 +62,7 @@ class FileService:
# generate file key # generate file key
file_uuid = str(uuid.uuid4()) file_uuid = str(uuid.uuid4())
if isinstance(user, Account): current_tenant_id = extract_tenant_id(user)
current_tenant_id = user.current_tenant_id
else:
# end_user
current_tenant_id = user.tenant_id
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension

View File

@@ -0,0 +1,65 @@
import pytest
from libs.helper import extract_tenant_id
from models.account import Account
from models.model import EndUser
class TestExtractTenantId:
"""Test cases for the extract_tenant_id utility function."""
def test_extract_tenant_id_from_account_with_tenant(self):
"""Test extracting tenant_id from Account with current_tenant_id."""
# Create a mock Account object
account = Account()
# Mock the current_tenant_id property
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
tenant_id = extract_tenant_id(account)
assert tenant_id == "account-tenant-123"
def test_extract_tenant_id_from_account_without_tenant(self):
"""Test extracting tenant_id from Account without current_tenant_id."""
# Create a mock Account object
account = Account()
account._current_tenant = None
tenant_id = extract_tenant_id(account)
assert tenant_id is None
def test_extract_tenant_id_from_enduser_with_tenant(self):
"""Test extracting tenant_id from EndUser with tenant_id."""
# Create a mock EndUser object
end_user = EndUser()
end_user.tenant_id = "enduser-tenant-456"
tenant_id = extract_tenant_id(end_user)
assert tenant_id == "enduser-tenant-456"
def test_extract_tenant_id_from_enduser_without_tenant(self):
"""Test extracting tenant_id from EndUser without tenant_id."""
# Create a mock EndUser object
end_user = EndUser()
end_user.tenant_id = None
tenant_id = extract_tenant_id(end_user)
assert tenant_id is None
def test_extract_tenant_id_with_invalid_user_type(self):
"""Test extracting tenant_id with invalid user type raises ValueError."""
invalid_user = "not_a_user_object"
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(invalid_user)
def test_extract_tenant_id_with_none_user(self):
"""Test extracting tenant_id with None user raises ValueError."""
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(None)
def test_extract_tenant_id_with_dict_user(self):
"""Test extracting tenant_id with dict user raises ValueError."""
dict_user = {"id": "123", "tenant_id": "456"}
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)

View File

@@ -9,6 +9,7 @@ from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from core.variables.segments import IntegerSegment, Segment from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment from factories.variable_factory import build_segment
from models.model import EndUser
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
@@ -43,7 +44,7 @@ def test_environment_variables():
) )
# Mock current_user as an EndUser # Mock current_user as an EndUser
mock_user = mock.Mock() mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id" mock_user.tenant_id = "tenant_id"
with ( with (
@@ -90,7 +91,7 @@ def test_update_environment_variables():
) )
# Mock current_user as an EndUser # Mock current_user as an EndUser
mock_user = mock.Mock() mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id" mock_user.tenant_id = "tenant_id"
with ( with (
@@ -136,7 +137,7 @@ def test_to_dict():
# Create some EnvironmentVariable instances # Create some EnvironmentVariable instances
# Mock current_user as an EndUser # Mock current_user as an EndUser
mock_user = mock.Mock() mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id" mock_user.tenant_id = "tenant_id"
with ( with (