feat: backend model load balancing support (#4927)
This commit is contained in:
@@ -11,6 +11,8 @@ from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
ModelLoadBalancingConfiguration,
|
||||
ModelSettings,
|
||||
QuotaConfiguration,
|
||||
SystemConfiguration,
|
||||
)
|
||||
@@ -26,13 +28,16 @@ from core.model_runtime.model_providers import model_provider_factory
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderModel,
|
||||
ProviderModelSetting,
|
||||
ProviderQuotaType,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
@@ -98,6 +103,13 @@ class ProviderManager:
|
||||
# Get All preferred provider types of the workspace
|
||||
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
||||
|
||||
# Get All provider model settings
|
||||
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
|
||||
|
||||
# Get All load balancing configs
|
||||
provider_name_to_provider_load_balancing_model_configs_dict \
|
||||
= self._get_all_provider_load_balancing_configs(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
@@ -147,13 +159,28 @@ class ProviderManager:
|
||||
if system_configuration.enabled and has_valid_quota:
|
||||
using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Get provider model settings
|
||||
provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
|
||||
|
||||
# Get provider load balancing configs
|
||||
provider_load_balancing_configs \
|
||||
= provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
|
||||
|
||||
# Convert to model settings
|
||||
model_settings = self._to_model_settings(
|
||||
provider_entity=provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=provider_load_balancing_configs
|
||||
)
|
||||
|
||||
provider_configuration = ProviderConfiguration(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_entity,
|
||||
preferred_provider_type=preferred_provider_type,
|
||||
using_provider_type=using_provider_type,
|
||||
system_configuration=system_configuration,
|
||||
custom_configuration=custom_configuration
|
||||
custom_configuration=custom_configuration,
|
||||
model_settings=model_settings
|
||||
)
|
||||
|
||||
provider_configurations[provider_name] = provider_configuration
|
||||
@@ -338,7 +365,7 @@ class ProviderManager:
|
||||
"""
|
||||
Get All preferred provider types of the workspace.
|
||||
|
||||
:param tenant_id:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
|
||||
@@ -353,6 +380,48 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_preferred_provider_type_records_dict
|
||||
|
||||
def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
    """
    Load the provider model settings of a workspace, grouped by provider name.

    :param tenant_id: workspace id
    :return: mapping of provider name to the ProviderModelSetting records
        stored for that provider (a defaultdict, so unknown names yield an
        empty list when indexed)
    """
    tenant_filter = ProviderModelSetting.tenant_id == tenant_id
    records = db.session.query(ProviderModelSetting).filter(tenant_filter).all()

    # Bucket the records under the provider they belong to.
    settings_by_provider = defaultdict(list)
    for record in records:
        bucket = settings_by_provider[record.provider_name]
        bucket.append(record)

    return settings_by_provider
|
||||
|
||||
def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
    """
    Load the load balancing model configs of a workspace, grouped by provider name.

    :param tenant_id: workspace id
    :return: mapping of provider name to its LoadBalancingModelConfig records;
        an empty dict when the workspace's features report model load
        balancing as disabled
    """
    features = FeatureService.get_features(tenant_id)
    if not features.model_load_balancing_enabled:
        # Feature is off for this workspace: do not query the database at all.
        return {}

    tenant_filter = LoadBalancingModelConfig.tenant_id == tenant_id
    records = db.session.query(LoadBalancingModelConfig).filter(tenant_filter).all()

    # Bucket the records under the provider they belong to.
    configs_by_provider = defaultdict(list)
    for record in records:
        bucket = configs_by_provider[record.provider_name]
        bucket.append(record)

    return configs_by_provider
|
||||
|
||||
def _init_trial_provider_records(self, tenant_id: str,
|
||||
provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
|
||||
"""
|
||||
@@ -726,3 +795,97 @@ class ProviderManager:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def _to_model_settings(self, provider_entity: ProviderEntity,
                       provider_model_settings: Optional[list[ProviderModelSetting]] = None,
                       load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
        -> list[ModelSettings]:
    """
    Build one ModelSettings entry per provider model setting.

    For a setting with load balancing enabled, every enabled load balancing
    config whose model name and model type match the setting is turned into a
    ModelLoadBalancingConfiguration. Credentials come from the credentials
    cache when present; otherwise the stored JSON config is parsed, its secret
    variables are decrypted, and the result is written back to the cache.

    :param provider_entity: provider entity whose model credential schema
        determines which credential variables are secret
    :param provider_model_settings: provider model settings include enabled, load balancing enabled
    :param load_balancing_model_configs: load balancing model configs
    :return: list of model settings (empty when no provider model settings are given)
    """
    # Secret variables declared by the provider's model credential schema.
    credential_schema = provider_entity.model_credential_schema
    secret_variables = self._extract_secret_variables(
        credential_schema.credential_form_schemas if credential_schema else []
    )

    result = []
    if not provider_model_settings:
        return result

    for setting in provider_model_settings:
        # Only walk the configs when balancing is switched on for this model
        # and there is at least one config available.
        if setting.load_balancing_enabled and load_balancing_model_configs:
            candidates = load_balancing_model_configs
        else:
            candidates = []

        balancing_entries = []
        for lb_config in candidates:
            # Ignore configs that belong to a different model.
            if (lb_config.model_name != setting.model_name
                    or lb_config.model_type != setting.model_type):
                continue

            if not lb_config.enabled:
                continue

            if not lb_config.encrypted_config:
                # No stored credentials: only the "__inherit__" entry is kept,
                # with empty credentials (presumably resolved from the
                # provider's own credentials elsewhere -- confirm with caller).
                if lb_config.name == "__inherit__":
                    balancing_entries.append(ModelLoadBalancingConfiguration(
                        id=lb_config.id,
                        name=lb_config.name,
                        credentials={}
                    ))
                continue

            credentials_cache = ProviderCredentialsCache(
                tenant_id=lb_config.tenant_id,
                identity_id=lb_config.id,
                cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
            )

            cached_credentials = credentials_cache.get()
            if cached_credentials:
                credentials = cached_credentials
            else:
                try:
                    credentials = json.loads(lb_config.encrypted_config)
                except JSONDecodeError:
                    # Unparseable config: skip this entry.
                    continue

                # Fetch the decoding key and cipher only if this instance has
                # not stored them yet.
                if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
                    self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
                        lb_config.tenant_id)

                for secret_name in secret_variables:
                    if secret_name not in credentials:
                        continue
                    try:
                        credentials[secret_name] = encrypter.decrypt_token_with_decoding(
                            credentials.get(secret_name),
                            self.decoding_rsa_key,
                            self.decoding_cipher_rsa
                        )
                    except ValueError:
                        # Best effort: keep the stored value when decryption fails.
                        pass

                # Store the decrypted credentials for subsequent lookups.
                credentials_cache.set(credentials=credentials)

            balancing_entries.append(ModelLoadBalancingConfiguration(
                id=lb_config.id,
                name=lb_config.name,
                credentials=credentials
            ))

        # A list with fewer than two entries is replaced by an empty list.
        effective_entries = balancing_entries if len(balancing_entries) > 1 else []

        result.append(ModelSettings(
            model=setting.model_name,
            model_type=ModelType.value_of(setting.model_type),
            enabled=setting.enabled,
            load_balancing_configs=effective_entries
        ))

    return result
|
||||
|
Reference in New Issue
Block a user