feat: add testcontainers-based tests for model load balancing service (#24066)

NeatGuyCoding authored 2025-08-18 09:54:22 +08:00, committed by GitHub
parent 97b24f48d5
commit 80f0594f4b


@@ -0,0 +1,474 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting
from services.model_load_balancing_service import ModelLoadBalancingService


class TestModelLoadBalancingService:
"""Integration tests for ModelLoadBalancingService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
):
# Setup default mock returns
mock_provider_manager_instance = mock_provider_manager.return_value
# Mock provider configuration
mock_provider_config = MagicMock()
mock_provider_config.provider.provider = "openai"
mock_provider_config.custom_configuration.provider = None
# Mock provider model setting
mock_provider_model_setting = MagicMock()
mock_provider_model_setting.load_balancing_enabled = False
mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting
# Mock provider configurations dict
mock_provider_configs = {"openai": mock_provider_config}
mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs
# Mock LBModelManager
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
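# (in_cooldown, ttl): no load balancing config is in cooldown, so the tests below
# expect in_cooldown=False and ttl=0 for every returned config entry.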
# Mock ModelProviderFactory
mock_model_provider_factory_instance = mock_model_provider_factory.return_value
# Mock credential schemas
mock_credential_schema = MagicMock()
mock_credential_schema.credential_form_schemas = []
# Mock provider configuration methods
mock_provider_config.extract_secret_variables.return_value = []
mock_provider_config.obfuscated_credentials.return_value = {}
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
yield {
"provider_manager": mock_provider_manager,
"lb_model_manager": mock_lb_model_manager,
"model_provider_factory": mock_model_provider_factory,
"encrypter": mock_encrypter,
"provider_config": mock_provider_config,
"provider_model_setting": mock_provider_model_setting,
"credential_schema": mock_credential_schema,
}
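# Note: ModelLoadBalancingService builds its own ProviderManager and resolves the
# provider via get_configurations(tenant_id), so the classes patched above stand in
# for every provider lookup and cooldown check performed by the tests below.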
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant

def _create_test_provider_and_setting(
self, db_session_with_containers, tenant_id, mock_external_service_dependencies
):
"""
Helper method to create a test provider and provider model setting.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
tenant_id: Tenant ID for the provider
mock_external_service_dependencies: Mock dependencies
Returns:
tuple: (provider, provider_model_setting) - Created provider and setting instances
"""
fake = Faker()
from extensions.ext_database import db
# Create provider
provider = Provider(
tenant_id=tenant_id,
provider_name="openai",
provider_type="custom",
is_valid=True,
)
db.session.add(provider)
db.session.commit()
# Create provider model setting
provider_model_setting = ProviderModelSetting(
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
enabled=True,
load_balancing_enabled=False,
)
db.session.add(provider_model_setting)
db.session.commit()
return provider, provider_model_setting
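# Note: rows are stored with the origin model type ("text-generation") while the
# service API is exercised with model_type="llm"; the service is expected to translate
# between the two forms when querying, which is why both spellings appear in these tests.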
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful model load balancing enablement.
This test verifies:
- Proper provider configuration retrieval
- Successful enablement of model load balancing
- Correct method calls to provider configuration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Setup mocks for enable method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.enable_model_load_balancing = MagicMock()
# Act: Execute the method under test
service = ModelLoadBalancingService()
service.enable_model_load_balancing(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
mock_provider_config.enable_model_load_balancing.assert_called_once()
call_args = mock_provider_config.enable_model_load_balancing.call_args
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None

def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful model load balancing disablement.
This test verifies:
- Proper provider configuration retrieval
- Successful disablement of model load balancing
- Correct method calls to provider configuration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Setup mocks for disable method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.disable_model_load_balancing = MagicMock()
# Act: Execute the method under test
service = ModelLoadBalancingService()
service.disable_model_load_balancing(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
mock_provider_config.disable_model_load_balancing.assert_called_once()
call_args = mock_provider_config.disable_model_load_balancing.call_args
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None

def test_enable_model_load_balancing_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist.
This test verifies:
- Proper error handling for non-existent provider
- Correct exception type and message
- No database state changes
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to return empty provider configurations
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
mock_provider_manager_instance = mock_provider_manager.return_value
mock_provider_manager_instance.get_configurations.return_value = {}
# Act & Assert: Verify proper error handling
service = ModelLoadBalancingService()
with pytest.raises(ValueError) as exc_info:
service.enable_model_load_balancing(
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
)
# Verify correct error message
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Roll back the session so the failed call leaves no state behind for later tests
from extensions.ext_database import db
db.session.rollback()

def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of load balancing configurations.
This test verifies:
- Proper provider configuration retrieval
- Successful database query for load balancing configs
- Correct return format and data structure
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
# Verify the config was created
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Setup mocks for get_load_balancing_configs method
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
mock_provider_model_setting.load_balancing_enabled = True
# Mock credential schema methods
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
mock_credential_schema.credential_form_schemas = []
# Mock encrypter
mock_encrypter = mock_external_service_dependencies["encrypter"]
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
# Mock _get_credential_schema method
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
# Mock extract_secret_variables method
mock_provider_config.extract_secret_variables.return_value = []
# Mock obfuscated_credentials method
mock_provider_config.obfuscated_credentials.return_value = {}
# Mock LBModelManager.get_config_in_cooldown_and_ttl
mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"]
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
# Act: Execute the method under test
service = ModelLoadBalancingService()
is_enabled, configs = service.get_load_balancing_configs(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
assert is_enabled is True
assert len(configs) == 1
assert configs[0]["id"] == load_balancing_config.id
assert configs[0]["name"] == "config1"
assert configs[0]["enabled"] is True
assert configs[0]["in_cooldown"] is False
assert configs[0]["ttl"] == 0
# Verify database state
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None

def test_get_load_balancing_configs_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist in get_load_balancing_configs.
This test verifies:
- Proper error handling for non-existent provider
- Correct exception type and message
- No database state changes
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to return empty provider configurations
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
mock_provider_manager_instance = mock_provider_manager.return_value
mock_provider_manager_instance.get_configurations.return_value = {}
# Act & Assert: Verify proper error handling
service = ModelLoadBalancingService()
with pytest.raises(ValueError) as exc_info:
service.get_load_balancing_configs(
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
)
# Verify correct error message
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Roll back the session so the failed call leaves no state behind for later tests
from extensions.ext_database import db
db.session.rollback()

def test_get_load_balancing_configs_with_inherit_config(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test load balancing configs retrieval with inherit configuration.
This test verifies:
- Proper handling of inherit configuration
- Correct ordering of configurations
- Inherit config initialization when needed
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider, provider_model_setting = self._create_test_provider_and_setting(
db_session_with_containers, tenant.id, mock_external_service_dependencies
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
# Setup mocks for inherit config scenario
mock_provider_config = mock_external_service_dependencies["provider_config"]
mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
mock_provider_model_setting.load_balancing_enabled = True
# Mock credential schema methods
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
mock_credential_schema.credential_form_schemas = []
# Mock encrypter
mock_encrypter = mock_external_service_dependencies["encrypter"]
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
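# With a custom provider credential present and load balancing enabled, the service is
# expected to create a placeholder "__inherit__" LoadBalancingModelConfig representing
# the provider's own credentials, so two entries are asserted below.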
# Act: Execute the method under test
service = ModelLoadBalancingService()
is_enabled, configs = service.get_load_balancing_configs(
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
)
# Assert: Verify the expected outcomes
assert is_enabled is True
assert len(configs) == 2 # inherit config + existing config
# First config should be inherit config
assert configs[0]["name"] == "__inherit__"
assert configs[0]["enabled"] is True
# Second config should be the existing config
assert configs[1]["id"] == load_balancing_config.id
assert configs[1]["name"] == "config1"
# Verify database state
db.session.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Verify inherit config was created in database
inherit_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all()
)
assert len(inherit_configs) == 1
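# Note: these tests assume a `db_session_with_containers` fixture supplied by the suite's
# testcontainers infrastructure (typically in conftest.py) that starts a disposable
# PostgreSQL container and binds `extensions.ext_database.db` to it. A minimal sketch of
# such a fixture (names and image tag are illustrative, not the project's actual setup):
#
#     import pytest
#     from testcontainers.postgres import PostgresContainer
#
#     @pytest.fixture(scope="session")
#     def db_session_with_containers():
#         from extensions.ext_database import db
#         with PostgresContainer("postgres:15") as postgres:
#             # point db's engine at postgres.get_connection_url(),
#             # create the schema, then hand the session to the tests
#             yield db.session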