feat: improve multi model credentials (#25009)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from core.entities.provider_entities import (
|
||||
CustomModelConfiguration,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel):
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: Optional[list[CredentialConfiguration]] = None
|
||||
custom_models: Optional[list[CustomModelConfiguration]] = None
|
||||
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
|
||||
|
@@ -3,6 +3,8 @@ import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
from core.helper import encrypter
|
||||
@@ -69,7 +71,7 @@ class ModelLoadBalancingService:
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
@@ -100,6 +102,11 @@ class ModelLoadBalancingService:
|
||||
if provider_model_setting and provider_model_setting.load_balancing_enabled:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
@@ -108,6 +115,10 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
@@ -405,7 +416,7 @@ class ModelLoadBalancingService:
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
else:
|
||||
# create load balancing config
|
||||
if name in {"__inherit__", "__delete__"}:
|
||||
if name == "__inherit__":
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if credential_id:
|
||||
|
@@ -72,6 +72,7 @@ class ModelProviderService:
|
||||
|
||||
provider_config = provider_configuration.custom_configuration.provider
|
||||
model_config = provider_configuration.custom_configuration.models
|
||||
can_added_models = provider_configuration.custom_configuration.can_added_models
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
@@ -95,6 +96,7 @@ class ModelProviderService:
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
custom_models=model_config,
|
||||
can_added_models=can_added_models,
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
@@ -152,7 +154,7 @@ class ModelProviderService:
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
@@ -172,7 +174,7 @@ class ModelProviderService:
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update a saved provider credential (by credential_id).
|
||||
@@ -249,7 +251,7 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
create and save model credentials.
|
||||
@@ -278,7 +280,7 @@ class ModelProviderService:
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update model credentials.
|
||||
|
Reference in New Issue
Block a user