feat: backend model load balancing support (#4927)

This commit is contained in:
takatost
2024-06-05 00:13:04 +08:00
committed by GitHub
parent 52ec152dd3
commit d1dbbc1e33
47 changed files with 2191 additions and 256 deletions

View File

@@ -11,6 +11,8 @@ from core.entities.provider_entities import (
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
ModelLoadBalancingConfiguration,
ModelSettings,
QuotaConfiguration,
SystemConfiguration,
)
@@ -26,13 +28,16 @@ from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
from services.feature_service import FeatureService
class ProviderManager:
@@ -98,6 +103,13 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
# Get All load balancing configs
provider_name_to_provider_load_balancing_model_configs_dict \
= self._get_all_provider_load_balancing_configs(tenant_id)
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
)
@@ -147,13 +159,28 @@ class ProviderManager:
if system_configuration.enabled and has_valid_quota:
using_provider_type = ProviderType.SYSTEM
# Get provider model settings
provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
# Get provider load balancing configs
provider_load_balancing_configs \
= provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
# Convert to model settings
model_settings = self._to_model_settings(
provider_entity=provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=provider_load_balancing_configs
)
provider_configuration = ProviderConfiguration(
tenant_id=tenant_id,
provider=provider_entity,
preferred_provider_type=preferred_provider_type,
using_provider_type=using_provider_type,
system_configuration=system_configuration,
custom_configuration=custom_configuration
custom_configuration=custom_configuration,
model_settings=model_settings
)
provider_configurations[provider_name] = provider_configuration
@@ -338,7 +365,7 @@ class ProviderManager:
"""
Get All preferred provider types of the workspace.
:param tenant_id:
:param tenant_id: workspace id
:return:
"""
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
@@ -353,6 +380,48 @@ class ProviderManager:
return provider_name_to_preferred_provider_type_records_dict
def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
    """
    Fetch every provider model setting of the workspace, grouped by provider.

    :param tenant_id: workspace id
    :return: mapping of provider name to the ProviderModelSetting records stored for that provider
    """
    settings_by_provider = defaultdict(list)

    tenant_settings_query = db.session.query(ProviderModelSetting).filter(
        ProviderModelSetting.tenant_id == tenant_id
    )
    for setting in tenant_settings_query.all():
        # Records are bucketed under their provider name in query order.
        settings_by_provider[setting.provider_name].append(setting)

    return settings_by_provider
def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
    """
    Fetch every load balancing config of the workspace, grouped by provider.

    :param tenant_id: workspace id
    :return: mapping of provider name to its LoadBalancingModelConfig records;
        empty when the model load balancing feature is not enabled for the workspace
    """
    features = FeatureService.get_features(tenant_id)
    if not features.model_load_balancing_enabled:
        # Feature switched off: skip the database lookup entirely.
        return {}

    configs_by_provider = defaultdict(list)

    tenant_configs_query = db.session.query(LoadBalancingModelConfig).filter(
        LoadBalancingModelConfig.tenant_id == tenant_id
    )
    for config in tenant_configs_query.all():
        configs_by_provider[config.provider_name].append(config)

    return configs_by_provider
def _init_trial_provider_records(self, tenant_id: str,
provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
"""
@@ -726,3 +795,97 @@ class ProviderManager:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def _to_model_settings(self, provider_entity: ProviderEntity,
                       provider_model_settings: Optional[list[ProviderModelSetting]] = None,
                       load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
        -> list[ModelSettings]:
    """
    Convert to model settings.

    :param provider_entity: provider entity; its model credential schema (if any) decides which
        credential variables are treated as secret and decrypted
    :param provider_model_settings: provider model settings include enabled, load balancing enabled
    :param load_balancing_model_configs: load balancing model configs
    :return: one ModelSettings per provider model setting, in input order; empty list when there
        are no provider model settings
    """
    # Nothing to convert: return before walking the credential schema.
    if not provider_model_settings:
        return []

    # Get provider model credential secret variables
    model_credential_secret_variables = self._extract_secret_variables(
        provider_entity.model_credential_schema.credential_form_schemas
        if provider_entity.model_credential_schema else []
    )

    # Group the configs once by (model name, model type) so each model setting does a single
    # lookup instead of re-scanning the whole list. Original order is kept within each group.
    configs_by_model = defaultdict(list)
    for load_balancing_model_config in load_balancing_model_configs or []:
        model_key = (load_balancing_model_config.model_name, load_balancing_model_config.model_type)
        configs_by_model[model_key].append(load_balancing_model_config)

    model_settings = []
    for provider_model_setting in provider_model_settings:
        load_balancing_configs = []
        if provider_model_setting.load_balancing_enabled:
            model_key = (provider_model_setting.model_name, provider_model_setting.model_type)
            for load_balancing_model_config in configs_by_model.get(model_key, []):
                configuration = self._to_load_balancing_configuration(
                    load_balancing_model_config=load_balancing_model_config,
                    model_credential_secret_variables=model_credential_secret_variables
                )
                if configuration is not None:
                    load_balancing_configs.append(configuration)

        model_settings.append(
            ModelSettings(
                model=provider_model_setting.model_name,
                model_type=ModelType.value_of(provider_model_setting.model_type),
                enabled=provider_model_setting.enabled,
                # Fewer than two usable configs are reported as an empty list.
                load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else []
            )
        )

    return model_settings

def _to_load_balancing_configuration(self, load_balancing_model_config: LoadBalancingModelConfig,
                                     model_credential_secret_variables: list[str]) \
        -> Optional[ModelLoadBalancingConfiguration]:
    """
    Build the load balancing configuration for a single load balancing config record.

    :param load_balancing_model_config: load balancing model config record
    :param model_credential_secret_variables: names of the credential variables to decrypt
    :return: the configuration, or None when the record is disabled, has no stored credentials
        (and is not the "__inherit__" entry), or its stored credentials are not valid JSON
    """
    if not load_balancing_model_config.enabled:
        return None

    if not load_balancing_model_config.encrypted_config:
        # Without stored credentials only the "__inherit__" entry is kept, with empty credentials.
        if load_balancing_model_config.name == "__inherit__":
            return ModelLoadBalancingConfiguration(
                id=load_balancing_model_config.id,
                name=load_balancing_model_config.name,
                credentials={}
            )
        return None

    provider_model_credentials_cache = ProviderCredentialsCache(
        tenant_id=load_balancing_model_config.tenant_id,
        identity_id=load_balancing_model_config.id,
        cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
    )

    # Get cached provider model credentials
    cached_provider_model_credentials = provider_model_credentials_cache.get()
    if cached_provider_model_credentials:
        provider_model_credentials = cached_provider_model_credentials
    else:
        try:
            provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
        except JSONDecodeError:
            # Unparseable stored credentials: skip this record.
            return None

        # Get decoding rsa key and cipher for decrypting credentials
        if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
            self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
                load_balancing_model_config.tenant_id)

        for variable in model_credential_secret_variables:
            if variable in provider_model_credentials:
                try:
                    provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
                        provider_model_credentials.get(variable),
                        self.decoding_rsa_key,
                        self.decoding_cipher_rsa
                    )
                except ValueError:
                    # Best effort: on a decryption failure the stored value is kept as-is.
                    pass

        # cache provider model credentials
        provider_model_credentials_cache.set(
            credentials=provider_model_credentials
        )

    return ModelLoadBalancingConfiguration(
        id=load_balancing_model_config.id,
        name=load_balancing_model_config.name,
        credentials=provider_model_credentials
    )