feat: improve multi model credentials (#25009)

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
非法操作
2025-09-03 13:52:31 +08:00
committed by GitHub
parent 9e125e2029
commit b673560b92
8 changed files with 332 additions and 150 deletions

View File

@@ -13,6 +13,7 @@ from core.entities.provider_entities import (
CustomModelConfiguration,
ProviderQuotaType,
QuotaConfiguration,
UnaddedModelConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
@@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel):
current_credential_name: Optional[str] = None
available_credentials: Optional[list[CredentialConfiguration]] = None
custom_models: Optional[list[CustomModelConfiguration]] = None
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
class SystemConfigurationResponse(BaseModel):

View File

@@ -3,6 +3,8 @@ import logging
from json import JSONDecodeError
from typing import Optional, Union
from sqlalchemy import or_
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
@@ -69,7 +71,7 @@ class ModelLoadBalancingService:
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
@@ -100,6 +102,11 @@ class ModelLoadBalancingService:
if provider_model_setting and provider_model_setting.load_balancing_enabled:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
else:
credential_source_type = "custom_model"
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
@@ -108,6 +115,10 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
@@ -405,7 +416,7 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name in {"__inherit__", "__delete__"}:
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
if credential_id:

View File

@@ -72,6 +72,7 @@ class ModelProviderService:
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
provider_response = ProviderResponse(
tenant_id=tenant_id,
@@ -95,6 +96,7 @@ class ModelProviderService:
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
@@ -152,7 +154,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@@ -172,7 +174,7 @@ class ModelProviderService:
provider: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@@ -249,7 +251,7 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
create and save model credentials.
@@ -278,7 +280,7 @@ class ModelProviderService:
model: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update model credentials.