Files
dify/api/core/entities/provider_configuration.py
Asuka Minato a78339a040 remove bare list, dict, Sequence, None, Any (#25058)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-09-06 03:32:23 +08:00

1840 lines
77 KiB
Python

import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
ModelSettings,
SystemConfiguration,
SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderCredential,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
ProviderType,
TenantPreferredModelProvider,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Process-wide registry: provider name -> the configurate methods that provider declared the
# first time a ProviderConfiguration was constructed for it (filled in ProviderConfiguration.__init__).
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
"""
Provider configuration entity for managing model provider settings.
This class handles:
- Provider credentials CRUD and switch
- Custom Model credentials CRUD and switch
- System vs custom provider switching
- Load balancing configurations
- Model enablement/disablement
TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
"""
# Tenant id; the queries issued by the methods below are filtered by it.
tenant_id: str
# Provider entity; `provider.provider` is the provider name used in queries and its
# credential schemas are used to find secret fields.
provider: ProviderEntity
# NOTE(review): presumably the tenant's preferred provider type — not read in the visible methods.
preferred_provider_type: ProviderType
# Compared with ProviderType.SYSTEM in get_current_credentials to pick system vs custom credentials.
using_provider_type: ProviderType
# System-side configuration (enabled flag, current quota type, quota configurations, credentials).
system_configuration: SystemConfiguration
# Custom configuration: provider-level credentials plus per-model configurations.
custom_configuration: CustomConfiguration
# Per-model settings; the `enabled` flag is checked in get_current_credentials.
model_settings: list[ModelSettings]
# pydantic configs
# protected_namespaces is emptied; per the pydantic docs this stops field names with the
# `model_` prefix (such as `model_settings`) from triggering a protected-namespace warning.
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
    """
    Build the entity and reconcile the provider's configurate methods.

    The methods a provider declares are remembered in the module-level
    ``original_provider_configurate_methods`` the first time that provider name is seen.
    When the remembered list is exactly ``[CUSTOMIZABLE_MODEL]`` and at least one system
    quota configuration restricts models, ``PREDEFINED_MODEL`` is appended to this
    instance's provider unless it is already present.

    :param data: field values forwarded to the pydantic constructor
    """
    super().__init__(**data)

    provider_name = self.provider.provider
    if provider_name not in original_provider_configurate_methods:
        # Store a copy so the append further down never alters the remembered list.
        original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

    if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return
    if ConfigurateMethod.PREDEFINED_MODEL in self.provider.configurate_methods:
        return

    has_restricted_models = any(
        len(quota.restrict_models) > 0 for quota in self.system_configuration.quota_configurations
    )
    if has_restricted_models:
        self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
    """
    Resolve the credentials to use for a model.

    :param model_type: model type
    :param model: model name
    :return: a copy of the system credentials (optionally carrying ``base_model_name``)
        when the system provider is in use; otherwise the matching custom model
        credentials, falling back to the custom provider credentials
    :raises ValueError: if a model setting marks the model as disabled
    """
    # A model switched off in the model settings must not be served.
    disabled = any(
        setting.model_type == model_type and setting.model == model and not setting.enabled
        for setting in self.model_settings or ()
    )
    if disabled:
        raise ValueError(f"Model {model} is disabled.")

    if self.using_provider_type != ProviderType.SYSTEM:
        custom = self.custom_configuration
        # First matching model configuration wins, even if its credentials are empty.
        credentials = next(
            (
                configuration.credentials
                for configuration in custom.models or ()
                if configuration.model_type == model_type and configuration.model == model
            ),
            None,
        )
        if not credentials and custom.provider:
            credentials = custom.provider.credentials
        return credentials

    system = self.system_configuration
    restrict_models = []
    for quota in system.quota_configurations:
        # No break: the last quota configuration of the current type is the one kept.
        if quota.quota_type == system.current_quota_type:
            restrict_models = quota.restrict_models

    # Work on a copy so the shared system credentials are never modified.
    copy_credentials = system.credentials.copy() if system.credentials else {}
    for restricted in restrict_models or ():
        if restricted.model_type == model_type and restricted.model == model and restricted.base_model_name:
            copy_credentials["base_model_name"] = restricted.base_model_name
    return copy_credentials
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
    """
    Report the status of the system configuration.

    :return: ``UNSUPPORTED`` when the system configuration is disabled, ``None`` when no
        quota configuration matches the current quota type, otherwise ``ACTIVE`` or
        ``QUOTA_EXCEEDED`` depending on the quota configuration's ``is_valid`` flag
    """
    system = self.system_configuration
    if system.enabled is False:
        return SystemConfigurationStatus.UNSUPPORTED

    wanted_quota_type = system.current_quota_type
    active_quota = None
    for quota in system.quota_configurations:
        if quota.quota_type == wanted_quota_type:
            active_quota = quota
            break

    if active_quota is None:
        return None
    # Only reachable for a quota configuration object that is falsy yet not None.
    if not active_quota:
        return SystemConfigurationStatus.UNSUPPORTED

    if active_quota.is_valid:
        return SystemConfigurationStatus.ACTIVE
    return SystemConfigurationStatus.QUOTA_EXCEEDED
def is_custom_configuration_available(self) -> bool:
    """
    Tell whether any custom configuration exists.

    :return: True when the custom provider has at least one available credential or at
        least one custom model configuration is present
    """
    custom = self.custom_configuration

    provider_ready = False
    if custom.provider is not None:
        provider_ready = len(custom.provider.available_credentials) > 0

    # Evaluated unconditionally, exactly like the provider check above.
    models_ready = len(custom.models) > 0
    return provider_ready or models_ready
def _get_provider_record(self, session: Session) -> Provider | None:
    """
    Fetch the tenant's custom provider record for this provider.

    The lookup accepts the provider name as configured and, when
    ``ModelProviderID.is_langgenius()`` is true, also the id's ``provider_name``.

    :param session: database session
    :return: the matching record, or None when there is none
    """
    configured_name = self.provider.provider
    candidate_names = [configured_name]

    provider_id = ModelProviderID(configured_name)
    if provider_id.is_langgenius():
        candidate_names.append(provider_id.provider_name)

    query = select(Provider).where(
        Provider.tenant_id == self.tenant_id,
        Provider.provider_type == ProviderType.CUSTOM.value,
        Provider.provider_name.in_(candidate_names),
    )
    return session.execute(query).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
    """
    Load one stored provider credential and return it in obfuscated form.

    :param credential_id: Credential ID
    :return: result of ``obfuscated_credentials`` for the decrypted credential values
    :raises ValueError: if the credential is missing or has no stored config
    """
    schema = self.provider.provider_credential_schema
    form_schemas = schema.credential_form_schemas if schema else []
    # Names of the credential fields that are stored encrypted.
    secret_keys = self.extract_secret_variables(form_schemas)

    with Session(db.engine) as session:
        # Use the stored provider record's name when there is one (handles aliased provider names).
        provider_record = self._get_provider_record(session)
        if provider_record:
            provider_name = provider_record.provider_name
        else:
            provider_name = self.provider.provider

        lookup = select(ProviderCredential).where(
            ProviderCredential.id == credential_id,
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == provider_name,
        )
        record = session.execute(lookup).scalar_one_or_none()
        if not record or not record.encrypted_config:
            raise ValueError(f"Credential with id {credential_id} not found.")

        try:
            values = json.loads(record.encrypted_config)
        except JSONDecodeError:
            # An unparsable stored config is treated as having no values.
            values = {}

        for secret_key in secret_keys:
            if secret_key not in values or values[secret_key] is None:
                continue
            try:
                values[secret_key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=values[secret_key])
            except Exception:
                # Best effort: a value that cannot be decrypted is passed on as stored.
                pass

        return self.obfuscated_credentials(credentials=values, credential_form_schemas=form_schemas)
def _check_provider_credential_name_exists(
    self, credential_name: str, session: Session, exclude_id: str | None = None
) -> bool:
    """
    Check whether a provider credential with the given name already exists.

    Duplicate names are not allowed when creating or updating a credential.

    :param credential_name: name to look for
    :param session: database session
    :param exclude_id: credential id to leave out of the check (the one being updated)
    :return: True if another credential of this tenant/provider carries the name
    """
    stmt = select(ProviderCredential.id).where(
        ProviderCredential.tenant_id == self.tenant_id,
        ProviderCredential.provider_name == self.provider.provider,
        ProviderCredential.credential_name == credential_name,
    )
    if exclude_id:
        stmt = stmt.where(ProviderCredential.id != exclude_id)
    # Only existence matters: fetch at most one row so that rows which already share the
    # name cannot make the check itself raise MultipleResultsFound.
    return session.execute(stmt.limit(1)).first() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
    """
    Return provider credentials in obfuscated form.

    :param credential_id: if provided, return the specified stored credential
    :return: the obfuscated credential values; without ``credential_id`` these are the
        credentials of the current custom provider configuration (or of ``{}`` if none)
    """
    if credential_id:
        return self._get_specific_provider_credential(credential_id)

    # No id given: fall back to the credentials currently held by the custom configuration.
    provider_config = self.custom_configuration.provider
    active_credentials = provider_config.credentials if provider_config else {}

    schema = self.provider.provider_credential_schema
    return self.obfuscated_credentials(
        credentials=active_credentials,
        credential_form_schemas=schema.credential_form_schemas if schema else [],
    )
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
    """
    Validate custom provider credentials and return them with secret fields encrypted.

    Note: secret fields sent as ``HIDDEN_VALUE`` are replaced in the passed-in
    ``credentials`` dict by the decrypted stored value before validation.

    :param credentials: provider credentials
    :param credential_id: (Optional) if provided, hidden secret inputs are restored from
        the stored credential with this id before validating
    :param session: optional database session; a new one is opened when omitted
    :return: the validated credentials with secret fields encrypted
    """

    def _validate(s: Session):
        # Names of the credential fields that must be stored encrypted.
        schema = self.provider.provider_credential_schema
        secret_keys = self.extract_secret_variables(schema.credential_form_schemas if schema else [])

        # Always bound, so the restore loop below is safe when no credential_id is given.
        original_credentials = {}
        if credential_id:
            try:
                stmt = select(ProviderCredential).where(
                    ProviderCredential.tenant_id == self.tenant_id,
                    ProviderCredential.provider_name == self.provider.provider,
                    ProviderCredential.id == credential_id,
                )
                credential_record = s.execute(stmt).scalar_one_or_none()
                stored_config = credential_record.encrypted_config if credential_record else None
                if not stored_config:
                    original_credentials = {}
                elif stored_config.startswith("{"):
                    original_credentials = json.loads(stored_config)
                else:
                    # Stored value is not a JSON object: treat it as a bare key.
                    original_credentials = {"openai_api_key": stored_config}
            except JSONDecodeError:
                original_credentials = {}

        # Restore hidden secrets: a secret input equal to HIDDEN_VALUE means "unchanged",
        # so the stored (encrypted) value is decrypted and used for validation.
        for key, value in credentials.items():
            if key not in secret_keys:
                continue
            if value == HIDDEN_VALUE and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(
                    tenant_id=self.tenant_id, token=original_credentials[key]
                )

        model_provider_factory = ModelProviderFactory(self.tenant_id)
        validated_credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )

        # Encrypt secret fields before they are handed back for storage.
        for key, value in validated_credentials.items():
            if key in secret_keys:
                validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
        return validated_credentials

    if session:
        return _validate(session)
    with Session(db.engine) as new_session:
        return _validate(new_session)
def _generate_provider_credential_name(self, session) -> str:
    """
    Produce a credential name for a new provider credential.

    :param session: database session
    :return: credential name computed by ``_generate_next_api_key_name`` over this
        tenant's credentials for the provider
    """

    def _provider_credentials_query():
        # Every stored credential of this tenant for this provider.
        return select(ProviderCredential).where(
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == self.provider.provider,
        )

    return self._generate_next_api_key_name(session=session, query_factory=_provider_credentials_query)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
    """
    Produce a credential name for a new custom model credential.

    :param model: model name
    :param model_type: model type
    :param session: database session
    :return: credential name computed by ``_generate_next_api_key_name`` over this
        tenant's credentials for the provider/model/model type
    """

    def _model_credentials_query():
        # Every stored credential of this tenant for the given provider model.
        return select(ProviderModelCredential).where(
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )

    return self._generate_next_api_key_name(session=session, query_factory=_model_credentials_query)
def _generate_next_api_key_name(self, session, query_factory) -> str:
    """
    Compute the next ``API KEY <n>`` name from the existing credential names.

    :param session: database session
    :param query_factory: function that returns the SQLAlchemy query
    :return: ``API KEY <highest existing number + 1>``; ``API KEY 1`` when no record
        exists, none matches the pattern, or any error occurs
    """
    try:
        records = session.execute(query_factory()).scalars().all()
        if not records:
            return "API KEY 1"

        name_pattern = re.compile(r"^API KEY\s+(\d+)$")
        highest = 0
        for record in records:
            name = record.credential_name
            if not name:
                continue
            # Surrounding whitespace is ignored when matching.
            found = name_pattern.match(name.strip())
            if found:
                highest = max(highest, int(found.group(1)))

        return f"API KEY {highest + 1}"
    except Exception as e:
        # Name generation must not break the caller: log and use the first name.
        logger.warning("Error generating next credential name: %s", str(e))
        return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None):
    """
    Add custom provider credentials.

    Validates and encrypts the credentials, stores them as a new ``ProviderCredential``,
    creates the custom ``Provider`` record when it does not exist yet, clears the
    provider credentials cache and switches the preferred provider type to custom.

    :param credentials: provider credentials
    :param credential_name: credential name; generated when empty
    :return:
    :raises ValueError: if a credential with the given name already exists
    """
    with Session(db.engine) as session:
        if credential_name:
            if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
                raise ValueError(f"Credential with name '{credential_name}' already exists.")
        else:
            credential_name = self._generate_provider_credential_name(session)

        credentials = self.validate_provider_credentials(credentials=credentials, session=session)
        provider_record = self._get_provider_record(session)
        try:
            new_record = ProviderCredential(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                encrypted_config=json.dumps(credentials),
                credential_name=credential_name,
            )
            session.add(new_record)
            # Flush so new_record.id is populated before it is referenced below.
            session.flush()

            if not provider_record:
                # If provider record does not exist, create it
                provider_record = Provider(
                    tenant_id=self.tenant_id,
                    provider_name=self.provider.provider,
                    provider_type=ProviderType.CUSTOM.value,
                    is_valid=True,
                    credential_id=new_record.id,
                )
                session.add(provider_record)
                # Flush for the same reason as above: provider_record.id is used as the
                # cache identity right after this and must not still be unset.
                session.flush()

            provider_model_credentials_cache = ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=provider_record.id,
                cache_type=ProviderCredentialsCacheType.PROVIDER,
            )
            provider_model_credentials_cache.delete()

            self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
            session.commit()
        except Exception:
            session.rollback()
            raise
def update_provider_credential(
    self,
    credentials: dict,
    credential_id: str,
    credential_name: str | None,
):
    """
    Update a stored provider credential identified by ``credential_id``.

    :param credentials: provider credentials
    :param credential_id: credential id
    :param credential_name: new credential name; the name is left unchanged when empty
    :return:
    :raises ValueError: if the name is taken by another credential or the record is missing
    """
    with Session(db.engine) as session:
        if credential_name:
            name_taken = self._check_provider_credential_name_exists(
                credential_name=credential_name, session=session, exclude_id=credential_id
            )
            if name_taken:
                raise ValueError(f"Credential with name '{credential_name}' already exists.")

        credentials = self.validate_provider_credentials(
            credentials=credentials, credential_id=credential_id, session=session
        )
        provider_record = self._get_provider_record(session)

        # Locate the credential that is about to be rewritten.
        lookup = select(ProviderCredential).where(
            ProviderCredential.id == credential_id,
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == self.provider.provider,
        )
        credential_record = session.execute(lookup).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        try:
            credential_record.encrypted_config = json.dumps(credentials)
            credential_record.updated_at = naive_utc_now()
            if credential_name:
                credential_record.credential_name = credential_name
            session.commit()

            # The provider cache only needs clearing when this credential is the one in use.
            if provider_record and provider_record.credential_id == credential_id:
                ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=provider_record.id,
                    cache_type=ProviderCredentialsCacheType.PROVIDER,
                ).delete()

            # Propagate the new config/name to load balancing configs using this credential.
            self._update_load_balancing_configs_with_credential(
                credential_id=credential_id,
                credential_record=credential_record,
                credential_source="provider",
                session=session,
            )
        except Exception:
            session.rollback()
            raise
def _update_load_balancing_configs_with_credential(
    self,
    credential_id: str,
    credential_record: ProviderCredential | ProviderModelCredential,
    credential_source: str,
    session: Session,
):
    """
    Copy a credential's config and name onto every load balancing config that uses it.

    :param credential_id: credential id
    :param credential_record: record supplying ``encrypted_config`` and ``credential_name``
    :param credential_source: ``provider`` for a provider credential or ``custom_model``
        for a provider model credential
    :param session: the database session; committed only when at least one config matched
    :return:
    """
    query = select(LoadBalancingModelConfig).where(
        LoadBalancingModelConfig.tenant_id == self.tenant_id,
        LoadBalancingModelConfig.provider_name == self.provider.provider,
        LoadBalancingModelConfig.credential_id == credential_id,
        LoadBalancingModelConfig.credential_source_type == credential_source,
    )
    affected_configs = session.execute(query).scalars().all()
    if not affected_configs:
        return

    for config in affected_configs:
        config.encrypted_config = credential_record.encrypted_config
        config.name = credential_record.credential_name
        config.updated_at = naive_utc_now()

        # Drop the cached credentials of this load balancing config.
        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=config.id,
            cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
        ).delete()

    session.commit()
def delete_provider_credential(self, credential_id: str):
    """
    Delete a stored provider credential identified by ``credential_id``.

    Load balancing configs that use the credential are deleted with it. When it was the
    last credential the provider record is deleted too; when it was only the active one
    the provider record's ``credential_id`` is cleared. In both of those cases the
    provider cache is cleared and the preferred provider type switches to system.

    :param credential_id: credential id
    :return:
    :raises ValueError: if the credential record does not exist
    """
    with Session(db.engine) as session:
        lookup = select(ProviderCredential).where(
            ProviderCredential.id == credential_id,
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == self.provider.provider,
        )
        credential_record = session.execute(lookup).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        # Load balancing configs that were built from this provider credential.
        lb_lookup = select(LoadBalancingModelConfig).where(
            LoadBalancingModelConfig.tenant_id == self.tenant_id,
            LoadBalancingModelConfig.provider_name == self.provider.provider,
            LoadBalancingModelConfig.credential_id == credential_id,
            LoadBalancingModelConfig.credential_source_type == "provider",
        )
        dependent_lb_configs = session.execute(lb_lookup).scalars().all()

        try:
            for lb_config in dependent_lb_configs:
                ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=lb_config.id,
                    cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                ).delete()
                session.delete(lb_config)

            provider_record = self._get_provider_record(session)

            # Count BEFORE deleting so that "<= 1" means this is the last credential.
            count_query = select(func.count(ProviderCredential.id)).where(
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == self.provider.provider,
            )
            remaining_before_delete = session.execute(count_query).scalar() or 0
            session.delete(credential_record)

            if provider_record:
                fall_back_to_system = True
                if remaining_before_delete <= 1:
                    # Last credential gone: the provider record goes with it.
                    session.delete(provider_record)
                elif provider_record.credential_id == credential_id:
                    # The active credential is removed while others remain.
                    provider_record.credential_id = None
                    provider_record.updated_at = naive_utc_now()
                else:
                    fall_back_to_system = False

                if fall_back_to_system:
                    ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_record.id,
                        cache_type=ProviderCredentialsCacheType.PROVIDER,
                    ).delete()
                    self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)

            session.commit()
        except Exception:
            session.rollback()
            raise
def switch_active_provider_credential(self, credential_id: str):
    """
    Make the given stored credential the provider's active credential.

    :param credential_id: credential id
    :return:
    :raises ValueError: if the credential record or the provider record does not exist
    """
    with Session(db.engine) as session:
        lookup = select(ProviderCredential).where(
            ProviderCredential.id == credential_id,
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == self.provider.provider,
        )
        chosen_credential = session.execute(lookup).scalar_one_or_none()
        if not chosen_credential:
            raise ValueError("Credential record not found.")

        provider_record = self._get_provider_record(session)
        if not provider_record:
            raise ValueError("Provider record not found.")

        try:
            provider_record.credential_id = chosen_credential.id
            provider_record.updated_at = naive_utc_now()
            session.commit()

            # Cached provider credentials now belong to the previous selection.
            ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=provider_record.id,
                cache_type=ProviderCredentialsCacheType.PROVIDER,
            ).delete()
            self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
        except Exception:
            session.rollback()
            raise
def _get_custom_model_record(
    self,
    model_type: ModelType,
    model: str,
    session: Session,
) -> ProviderModel | None:
    """
    Fetch the tenant's custom model record for a model of this provider.

    The lookup accepts the provider name as configured and, when
    ``ModelProviderID.is_langgenius()`` is true, also the id's ``provider_name``.

    :param model_type: model type
    :param model: model name
    :param session: database session
    :return: the matching record, or None when there is none
    """
    configured_name = self.provider.provider
    candidate_names = [configured_name]

    provider_id = ModelProviderID(configured_name)
    if provider_id.is_langgenius():
        candidate_names.append(provider_id.provider_name)

    query = select(ProviderModel).where(
        ProviderModel.tenant_id == self.tenant_id,
        ProviderModel.provider_name.in_(candidate_names),
        ProviderModel.model_name == model,
        ProviderModel.model_type == model_type.to_origin_model_type(),
    )
    return session.execute(query).scalar_one_or_none()
def _get_specific_custom_model_credential(
    self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
    """
    Load one stored custom model credential and return it in obfuscated form.

    :param model_type: model type
    :param model: model name
    :param credential_id: Credential ID
    :return: dict with ``current_credential_id``, ``current_credential_name`` and the
        obfuscated ``credentials``
    :raises ValueError: if the credential is missing or has no stored config
    """
    schema = self.provider.model_credential_schema
    form_schemas = schema.credential_form_schemas if schema else []
    # Names of the credential fields that are stored encrypted.
    secret_keys = self.extract_secret_variables(form_schemas)

    with Session(db.engine) as session:
        lookup = select(ProviderModelCredential).where(
            ProviderModelCredential.id == credential_id,
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )
        record = session.execute(lookup).scalar_one_or_none()
        if not record or not record.encrypted_config:
            raise ValueError(f"Credential with id {credential_id} not found.")

        try:
            values = json.loads(record.encrypted_config)
        except JSONDecodeError:
            # An unparsable stored config is treated as having no values.
            values = {}

        for secret_key in secret_keys:
            if secret_key not in values or values[secret_key] is None:
                continue
            try:
                values[secret_key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=values[secret_key])
            except Exception:
                # Best effort: a value that cannot be decrypted is passed on as stored.
                pass

        return {
            "current_credential_id": record.id,
            "current_credential_name": record.credential_name,
            "credentials": self.obfuscated_credentials(
                credentials=values,
                credential_form_schemas=form_schemas,
            ),
        }
def _check_custom_model_credential_name_exists(
    self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
) -> bool:
    """
    Check whether a custom model credential with the given name already exists.

    Duplicate names are not allowed when creating or updating a credential.

    :param model_type: model type
    :param model: model name
    :param credential_name: name to look for
    :param session: database session
    :param exclude_id: credential id to leave out of the check (the one being updated)
    :return: True if another credential of this tenant/provider/model carries the name
    """
    # Only the id is selected: the row itself is never used, just its existence.
    stmt = select(ProviderModelCredential.id).where(
        ProviderModelCredential.tenant_id == self.tenant_id,
        ProviderModelCredential.provider_name == self.provider.provider,
        ProviderModelCredential.model_name == model,
        ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        ProviderModelCredential.credential_name == credential_name,
    )
    if exclude_id:
        stmt = stmt.where(ProviderModelCredential.id != exclude_id)
    # Fetch at most one row so that rows which already share the name cannot make the
    # check itself raise MultipleResultsFound.
    return session.execute(stmt.limit(1)).first() is not None
def get_custom_model_credential(
    self, model_type: ModelType, model: str, credential_id: str | None
) -> Optional[dict]:
    """
    Return the credentials of a custom model in obfuscated form.

    :param model_type: model type
    :param model: model name
    :param credential_id: if provided, return that specific stored credential
    :return: dict with ``current_credential_id``, ``current_credential_name`` and the
        obfuscated ``credentials``; None when no configured model with credentials matches
    """
    if credential_id:
        return self._get_specific_custom_model_credential(
            model_type=model_type, model=model, credential_id=credential_id
        )

    for configuration in self.custom_configuration.models:
        if configuration.model_type != model_type or configuration.model != model:
            continue
        if not configuration.credentials:
            # A matching entry without credentials does not count; keep looking.
            continue

        schema = self.provider.model_credential_schema
        obfuscated = self.obfuscated_credentials(
            credentials=configuration.credentials,
            credential_form_schemas=schema.credential_form_schemas if schema else [],
        )
        return {
            "current_credential_id": configuration.current_credential_id,
            "current_credential_name": configuration.current_credential_name,
            "credentials": obfuscated,
        }

    return None
def validate_custom_model_credentials(
    self,
    model_type: ModelType,
    model: str,
    credentials: dict,
    credential_id: str = "",
    session: Session | None = None,
):
    """
    Validate custom model credentials and return them with secret fields encrypted.

    Note: secret fields sent as ``HIDDEN_VALUE`` are replaced in the passed-in
    ``credentials`` dict by the decrypted stored value before validation.

    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :param credential_id: (Optional) if provided, hidden secret inputs are restored from
        the stored credential with this id before validating
    :param session: optional database session; a new one is opened when omitted
    :return: the validated credentials with secret fields encrypted
    """

    def _validate(s: Session):
        # Names of the model credential fields that must be stored encrypted.
        schema = self.provider.model_credential_schema
        secret_keys = self.extract_secret_variables(schema.credential_form_schemas if schema else [])

        # Always bound, so the restore loop below is safe when no credential_id is given.
        original_credentials = {}
        if credential_id:
            try:
                stmt = select(ProviderModelCredential).where(
                    ProviderModelCredential.id == credential_id,
                    ProviderModelCredential.tenant_id == self.tenant_id,
                    ProviderModelCredential.provider_name == self.provider.provider,
                    ProviderModelCredential.model_name == model,
                    ProviderModelCredential.model_type == model_type.to_origin_model_type(),
                )
                credential_record = s.execute(stmt).scalar_one_or_none()
                if credential_record and credential_record.encrypted_config:
                    original_credentials = json.loads(credential_record.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}

        # Restore hidden secrets: a secret input equal to HIDDEN_VALUE means "unchanged",
        # so the stored (encrypted) value is decrypted and used for validation.
        for key, value in credentials.items():
            if key not in secret_keys:
                continue
            if value == HIDDEN_VALUE and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(
                    tenant_id=self.tenant_id, token=original_credentials[key]
                )

        model_provider_factory = ModelProviderFactory(self.tenant_id)
        validated_credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        # Encrypt secret fields before they are handed back for storage.
        for key, value in validated_credentials.items():
            if key in secret_keys:
                validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
        return validated_credentials

    if session:
        return _validate(session)
    with Session(db.engine) as new_session:
        return _validate(new_session)
def create_custom_model_credential(
    self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
    """
    Create a custom model credential.

    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :param credential_name: credential name; generated when empty
    :return:
    :raises ValueError: if a credential with the given name already exists for the model
    """
    with Session(db.engine) as session:
        if not credential_name:
            credential_name = self._generate_custom_model_credential_name(
                model=model, model_type=model_type, session=session
            )
        elif self._check_custom_model_credential_name_exists(
            model=model, model_type=model_type, credential_name=credential_name, session=session
        ):
            raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

        # Validate the custom model config; secret fields come back encrypted.
        credentials = self.validate_custom_model_credentials(
            model_type=model_type, model=model, credentials=credentials, session=session
        )
        provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

        try:
            origin_model_type = model_type.to_origin_model_type()
            new_credential = ProviderModelCredential(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=origin_model_type,
                encrypted_config=json.dumps(credentials),
                credential_name=credential_name,
            )
            session.add(new_credential)
            # Flush so new_credential.id is available for the model record below.
            session.flush()

            if not provider_model_record:
                provider_model_record = ProviderModel(
                    tenant_id=self.tenant_id,
                    provider_name=self.provider.provider,
                    model_name=model,
                    model_type=model_type.to_origin_model_type(),
                    credential_id=new_credential.id,
                    is_valid=True,
                )
                session.add(provider_model_record)

            session.commit()

            # Clear cached credentials of the model record once the change is committed.
            ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=provider_model_record.id,
                cache_type=ProviderCredentialsCacheType.MODEL,
            ).delete()
        except Exception:
            session.rollback()
            raise
def update_custom_model_credential(
    self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
    """
    Update a stored custom model credential identified by ``credential_id``.

    :param model_type: model type
    :param model: model name
    :param credentials: model credentials dict
    :param credential_name: new credential name; the name is left unchanged when empty
    :param credential_id: credential id
    :return:
    :raises ValueError: if the name is taken by another credential or the record is missing
    """
    with Session(db.engine) as session:
        if credential_name:
            name_taken = self._check_custom_model_credential_name_exists(
                model=model,
                model_type=model_type,
                credential_name=credential_name,
                session=session,
                exclude_id=credential_id,
            )
            if name_taken:
                raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

        # Validate the custom model config; secret fields come back encrypted.
        credentials = self.validate_custom_model_credentials(
            model_type=model_type,
            model=model,
            credentials=credentials,
            credential_id=credential_id,
            session=session,
        )
        provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

        lookup = select(ProviderModelCredential).where(
            ProviderModelCredential.id == credential_id,
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )
        credential_record = session.execute(lookup).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        try:
            credential_record.encrypted_config = json.dumps(credentials)
            credential_record.updated_at = naive_utc_now()
            if credential_name:
                credential_record.credential_name = credential_name
            session.commit()

            # The model cache only needs clearing when this credential is the one in use.
            if provider_model_record and provider_model_record.credential_id == credential_id:
                ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=provider_model_record.id,
                    cache_type=ProviderCredentialsCacheType.MODEL,
                ).delete()

            # Propagate the new config/name to load balancing configs using this credential.
            self._update_load_balancing_configs_with_credential(
                credential_id=credential_id,
                credential_record=credential_record,
                credential_source="custom_model",
                session=session,
            )
        except Exception:
            session.rollback()
            raise
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
    """
    Delete a saved custom model credential (by credential_id).

    Besides the credential row itself, this removes every load balancing config that
    references the credential, and then either deletes the custom model record (when the
    credential was the last one stored for the model) or detaches the credential from it.

    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    :raises ValueError: if no matching credential record exists
    """
    origin_model_type = model_type.to_origin_model_type()

    with Session(db.engine) as session:
        stmt = select(ProviderModelCredential).where(
            ProviderModelCredential.id == credential_id,
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == origin_model_type,
        )
        credential_record = session.execute(stmt).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        lb_stmt = select(LoadBalancingModelConfig).where(
            LoadBalancingModelConfig.tenant_id == self.tenant_id,
            LoadBalancingModelConfig.provider_name == self.provider.provider,
            LoadBalancingModelConfig.credential_id == credential_id,
            LoadBalancingModelConfig.credential_source_type == "custom_model",
        )
        lb_configs_using_credential = session.execute(lb_stmt).scalars().all()

        try:
            # Remove load balancing configs bound to this credential, together with their caches
            for lb_config in lb_configs_using_credential:
                lb_credentials_cache = ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=lb_config.id,
                    cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                )
                lb_credentials_cache.delete()
                session.delete(lb_config)

            # Custom model record that may currently be using this credential
            provider_model_record = self._get_custom_model_record(model_type, model, session=session)

            # Check available credentials count BEFORE deleting:
            # if this is the last credential, the custom model record has to go as well
            count_stmt = select(func.count(ProviderModelCredential.id)).where(
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == origin_model_type,
            )
            available_credentials_count = session.execute(count_stmt).scalar() or 0

            session.delete(credential_record)

            if provider_model_record:
                is_last_credential = available_credentials_count <= 1
                is_active_credential = provider_model_record.credential_id == credential_id

                if is_last_credential:
                    # If all credentials are deleted, delete the custom model record
                    session.delete(provider_model_record)
                elif is_active_credential:
                    provider_model_record.credential_id = None
                    provider_model_record.updated_at = naive_utc_now()

                if is_last_credential or is_active_credential:
                    # Custom model credentials are cached under the MODEL cache type keyed by
                    # the custom model record id (same key the update/delete-model paths use),
                    # so that is the entry that must be dropped here.
                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_model_record.id,
                        cache_type=ProviderCredentialsCacheType.MODEL,
                    )
                    provider_model_credentials_cache.delete()

            session.commit()
        except Exception:
            session.rollback()
            raise
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
    """
    if model list exist this custom model, switch the custom model credential.
    if model list not exist this custom model, use the credential to add a new custom model record.

    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    :raises ValueError: if the credential record does not exist, or if the custom model
        already uses this credential
    """
    with Session(db.engine) as session:
        stmt = select(ProviderModelCredential).where(
            ProviderModelCredential.id == credential_id,
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )
        credential_record = session.execute(stmt).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        # look up the custom model record this credential should be attached to
        provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

        # id of an already existing record whose cached credentials become stale
        stale_cache_identity_id = None

        if not provider_model_record:
            # create provider model record
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                is_valid=True,
                credential_id=credential_id,
            )
        else:
            if provider_model_record.credential_id == credential_record.id:
                raise ValueError("Can't add same credential")
            provider_model_record.credential_id = credential_record.id
            provider_model_record.updated_at = naive_utc_now()
            # read the id before commit() expires the instance
            stale_cache_identity_id = provider_model_record.id

        session.add(provider_model_record)
        session.commit()

        if stale_cache_identity_id is not None:
            # The model now points at a different credential: drop the cached credentials
            # keyed by the custom model record, as the credential update path does.
            provider_model_credentials_cache = ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=stale_cache_identity_id,
                cache_type=ProviderCredentialsCacheType.MODEL,
            )
            provider_model_credentials_cache.delete()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
    """
    switch the custom model credential.

    :param model_type: model type
    :param model: model name
    :param credential_id: credential id
    :return:
    :raises ValueError: if the credential record or the custom model record does not exist
    """
    with Session(db.engine) as session:
        stmt = select(ProviderModelCredential).where(
            ProviderModelCredential.id == credential_id,
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
        )
        credential_record = session.execute(stmt).scalar_one_or_none()
        if not credential_record:
            raise ValueError("Credential record not found.")

        provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
        if not provider_model_record:
            raise ValueError("The custom model record not found.")

        # read the id before commit() expires the instance
        provider_model_record_id = provider_model_record.id

        provider_model_record.credential_id = credential_record.id
        provider_model_record.updated_at = naive_utc_now()
        session.add(provider_model_record)
        session.commit()

        # The model now points at a different credential: drop the cached credentials keyed
        # by the custom model record, as the credential update path does.
        provider_model_credentials_cache = ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record_id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        )
        provider_model_credentials_cache.delete()
def delete_custom_model(self, model_type: ModelType, model: str):
    """
    Remove the custom model record stored for the given model.

    Nothing is done when no custom model record exists.

    :param model_type: model type
    :param model: model name
    :return:
    """
    with Session(db.engine) as session:
        record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
        if not record:
            return

        session.delete(record)
        session.commit()

        # the credentials cached for this record are no longer valid
        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=record.id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        ).delete()
def _get_provider_model_setting(
    self, model_type: ModelType, model: str, session: Session
) -> ProviderModelSetting | None:
    """
    Fetch the stored setting of a model using the given session.

    The lookup matches the configured provider name and, for langgenius providers,
    the provider name reported by the provider id as well.

    :param model_type: model type
    :param model: model name
    :param session: database session used for the query
    :return: first matching setting, or None when there is none
    """
    provider_id = ModelProviderID(self.provider.provider)
    candidate_names = [self.provider.provider]
    if provider_id.is_langgenius():
        candidate_names.append(provider_id.provider_name)

    conditions = (
        ProviderModelSetting.tenant_id == self.tenant_id,
        ProviderModelSetting.provider_name.in_(candidate_names),
        ProviderModelSetting.model_type == model_type.to_origin_model_type(),
        ProviderModelSetting.model_name == model,
    )
    query = select(ProviderModelSetting).where(*conditions)
    rows = session.execute(query).scalars()
    return rows.first()
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Enable model.

    Updates the existing model setting, or creates one when the model has no setting yet.

    :param model_type: model type
    :param model: model name
    :return: the persisted model setting, reloaded after commit so its attributes stay
        readable once the session has been closed
    """
    with Session(db.engine) as session:
        model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)

        if model_setting:
            model_setting.enabled = True
            model_setting.updated_at = naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                enabled=True,
            )
            session.add(model_setting)
        session.commit()

        # commit() expires every attribute (expire_on_commit defaults to True). Reload the
        # row while the session is still open; otherwise the returned, detached instance
        # raises DetachedInstanceError on its first attribute access.
        session.refresh(model_setting)

    return model_setting
def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Disable model.

    Updates the existing model setting, or creates one when the model has no setting yet.

    :param model_type: model type
    :param model: model name
    :return: the persisted model setting, reloaded after commit so its attributes stay
        readable once the session has been closed
    """
    with Session(db.engine) as session:
        model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)

        if model_setting:
            model_setting.enabled = False
            model_setting.updated_at = naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                enabled=False,
            )
            session.add(model_setting)
        session.commit()

        # commit() expires every attribute (expire_on_commit defaults to True). Reload the
        # row while the session is still open; otherwise the returned, detached instance
        # raises DetachedInstanceError on its first attribute access.
        session.refresh(model_setting)

    return model_setting
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
    """
    Look up the stored setting of a model in a short-lived session.

    :param model_type: model type
    :param model: model name
    :return: the model setting, or None when the model has no stored setting
    """
    with Session(db.engine) as session:
        setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
        return setting
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Enable model load balancing.

    Load balancing can only be enabled when more than one load balancing configuration
    is stored for the model.

    :param model_type: model type
    :param model: model name
    :return: the persisted model setting, reloaded after commit so its attributes stay
        readable once the session has been closed
    :raises ValueError: if the model has one or zero load balancing configurations
    """
    model_provider_id = ModelProviderID(self.provider.provider)
    provider_names = [self.provider.provider]
    if model_provider_id.is_langgenius():
        provider_names.append(model_provider_id.provider_name)

    with Session(db.engine) as session:
        stmt = select(func.count(LoadBalancingModelConfig.id)).where(
            LoadBalancingModelConfig.tenant_id == self.tenant_id,
            LoadBalancingModelConfig.provider_name.in_(provider_names),
            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
            LoadBalancingModelConfig.model_name == model,
        )
        load_balancing_config_count = session.execute(stmt).scalar() or 0
        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)

        if model_setting:
            model_setting.load_balancing_enabled = True
            model_setting.updated_at = naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                load_balancing_enabled=True,
            )
            session.add(model_setting)
        session.commit()

        # commit() expires every attribute (expire_on_commit defaults to True). Reload the
        # row while the session is still open; otherwise the returned, detached instance
        # raises DetachedInstanceError on its first attribute access.
        session.refresh(model_setting)

    return model_setting
def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Disable model load balancing.

    Updates the existing model setting, or creates one when the model has no setting yet.

    :param model_type: model type
    :param model: model name
    :return: the persisted model setting, reloaded after commit so its attributes stay
        readable once the session has been closed
    """
    with Session(db.engine) as session:
        model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)

        if model_setting:
            model_setting.load_balancing_enabled = False
            model_setting.updated_at = naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                load_balancing_enabled=False,
            )
            session.add(model_setting)
        session.commit()

        # commit() expires every attribute (expire_on_commit defaults to True). Reload the
        # row while the session is still open; otherwise the returned, detached instance
        # raises DetachedInstanceError on its first attribute access.
        session.refresh(model_setting)

    return model_setting
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
    """
    Obtain the model type instance of this provider for the given model type.

    :param model_type: model type
    :return: instance returned by the tenant's model provider factory
    """
    factory = ModelProviderFactory(self.tenant_id)
    instance = factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
    return instance
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
    """
    Resolve the schema of a model through the tenant's model provider factory.

    :param model_type: model type
    :param model: model name
    :param credentials: credentials handed to the factory, may be None
    :return: whatever the factory returns for the model (possibly None)
    """
    factory = ModelProviderFactory(self.tenant_id)
    schema = factory.get_model_schema(
        provider=self.provider.provider,
        model_type=model_type,
        model=model,
        credentials=credentials,
    )
    return schema
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
    """
    Persist a new preferred provider type for this tenant and provider.

    The call is a no-op when the requested type is already preferred, or when the system
    type is requested while the system configuration is disabled.

    :param provider_type: provider type to prefer
    :param session: optional session to run in; a new one is opened when omitted.
        Either way the session is committed.
    :return:
    """
    already_preferred = provider_type == self.preferred_provider_type
    if already_preferred or (provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled):
        return

    def _apply(active_session: Session):
        # langgenius providers may be stored under either provider name
        provider_id = ModelProviderID(self.provider.provider)
        candidate_names = [self.provider.provider]
        if provider_id.is_langgenius():
            candidate_names.append(provider_id.provider_name)

        query = select(TenantPreferredModelProvider).where(
            TenantPreferredModelProvider.tenant_id == self.tenant_id,
            TenantPreferredModelProvider.provider_name.in_(candidate_names),
        )
        preference = active_session.execute(query).scalars().first()

        if not preference:
            # no preference stored yet: create it
            preference = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            active_session.add(preference)
        else:
            preference.preferred_provider_type = provider_type.value

        active_session.commit()

    if session:
        return _apply(session)

    with Session(db.engine) as session:
        return _apply(session)
def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
    """
    Collect the variable names of all secret-input fields in the given form schemas.

    :param credential_form_schemas: credential form schemas
    :return: variable names in schema order
    """
    return [schema.variable for schema in credential_form_schemas if schema.type == FormType.SECRET_INPUT]
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
    """
    Return a copy of the credentials in which every secret value is obfuscated.

    The input dict is left untouched.

    :param credentials: credentials
    :param credential_form_schemas: credential form schemas
    :return: copy of the credentials with secret-input values obfuscated
    """
    # variables declared as secret inputs by the form schemas
    secret_keys = self.extract_secret_variables(credential_form_schemas)

    masked = credentials.copy()
    masked.update(
        {name: encrypter.obfuscated_token(secret) for name, secret in credentials.items() if name in secret_keys}
    )
    return masked
def get_provider_model(
    self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
    """
    Find a single provider model by name.

    :param model_type: model type
    :param model: model name
    :param only_active: return active model only
    :return: the first provider model whose name equals `model`, or None
    """
    candidates = self.get_provider_models(model_type, only_active, model)
    return next((candidate for candidate in candidates if candidate.model == model), None)
def get_provider_models(
    self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
) -> list[ModelWithProviderEntity]:
    """
    List the models of this provider, sorted by model type and configured position.

    :param model_type: model type; all model types supported by the provider schema when omitted
    :param only_active: only active models
    :param model: model name, forwarded to the custom provider model lookup only
    :return: sorted provider models
    """
    factory = ModelProviderFactory(self.tenant_id)
    schema = factory.get_provider_schema(self.provider.provider)

    wanted_types: list[ModelType] = [model_type] if model_type else list(schema.supported_model_types)

    # model type -> model name -> model settings
    settings_by_type: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
    for setting in self.model_settings:
        settings_by_type[setting.model_type][setting.model] = setting

    if self.using_provider_type != ProviderType.SYSTEM:
        found = self._get_custom_provider_models(
            model_types=wanted_types,
            provider_schema=schema,
            model_setting_map=settings_by_type,
            model=model,
        )
    else:
        found = self._get_system_provider_models(
            model_types=wanted_types, provider_schema=schema, model_setting_map=settings_by_type
        )

    if only_active:
        found = [entity for entity in found if entity.status == ModelStatus.ACTIVE]

    # Per model type ordering declared by the provider (looked up once, empty when absent)
    positions_by_type = getattr(self.provider, "position", None) or {}

    def _ordering(entity: ModelWithProviderEntity):
        # Models listed in the position list keep their index there;
        # unlisted models share the list length and therefore sort last.
        ordered_names = positions_by_type.get(entity.model_type.value, [])
        if entity.model in ordered_names:
            rank = ordered_names.index(entity.model)
        else:
            rank = len(ordered_names)
        return (entity.model_type.value, rank)

    return sorted(found, key=_ordering)
def _get_system_provider_models(
    self,
    model_types: Sequence[ModelType],
    provider_schema: ProviderEntity,
    model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
    """
    Build the model list used when the provider runs on system credentials.

    Predefined schema models are listed first. For providers whose original configurate
    methods are customizable-model only, the restricted models of the current quota are
    appended. Finally the statuses are adjusted according to the current quota configuration.

    :param model_types: model types
    :param provider_schema: provider schema
    :param model_setting_map: model setting map
    :return: provider models
    """
    provider_name = self.provider.provider

    def _status_from_settings(setting_type, setting_model) -> ModelStatus:
        # a model is active unless its stored setting explicitly disables it
        if setting_type in model_setting_map and setting_model in model_setting_map[setting_type]:
            if model_setting_map[setting_type][setting_model].enabled is False:
                return ModelStatus.DISABLED
        return ModelStatus.ACTIVE

    # predefined models, grouped in the order of the requested model types
    entities = [
        ModelWithProviderEntity(
            model=schema_model.model,
            label=schema_model.label,
            model_type=schema_model.model_type,
            features=schema_model.features,
            fetch_from=schema_model.fetch_from,
            model_properties=schema_model.model_properties,
            deprecated=schema_model.deprecated,
            provider=SimpleModelProviderEntity(self.provider),
            status=_status_from_settings(schema_model.model_type, schema_model.model),
        )
        for wanted_type in model_types
        for schema_model in provider_schema.models
        if schema_model.model_type == wanted_type
    ]

    # remember the configurate methods seen the first time this provider is processed
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(provider_schema.configurate_methods)

    only_customizable = original_provider_configurate_methods[provider_name] == [
        ConfigurateMethod.CUSTOMIZABLE_MODEL
    ]

    for quota_configuration in self.system_configuration.quota_configurations:
        if quota_configuration.quota_type != self.system_configuration.current_quota_type:
            continue

        restrict_models = quota_configuration.restrict_models
        if len(restrict_models) == 0:
            break

        if only_customizable:
            # only customizable model: resolve each restricted model through its schema
            for restrict_model in restrict_models:
                system_credentials = self.system_configuration.credentials
                scoped_credentials = system_credentials.copy() if system_credentials else {}
                if restrict_model.base_model_name:
                    scoped_credentials["base_model_name"] = restrict_model.base_model_name

                try:
                    restricted_schema = self.get_model_schema(
                        model_type=restrict_model.model_type,
                        model=restrict_model.model,
                        credentials=scoped_credentials,
                    )
                except Exception as ex:
                    logger.warning("get custom model schema failed, %s", ex)
                    continue

                if not restricted_schema or restricted_schema.model_type not in model_types:
                    continue

                entities.append(
                    ModelWithProviderEntity(
                        model=restricted_schema.model,
                        label=restricted_schema.label,
                        model_type=restricted_schema.model_type,
                        features=restricted_schema.features,
                        fetch_from=FetchFrom.PREDEFINED_MODEL,
                        model_properties=restricted_schema.model_properties,
                        deprecated=restricted_schema.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=_status_from_settings(restricted_schema.model_type, restricted_schema.model),
                    )
                )

        # LLMs outside the restricted list are not permitted; every other model is
        # flagged as quota exceeded when the quota is no longer valid
        permitted_names = {restricted.model for restricted in restrict_models}
        for entity in entities:
            if entity.model_type == ModelType.LLM and entity.model not in permitted_names:
                entity.status = ModelStatus.NO_PERMISSION
            elif not quota_configuration.is_valid:
                entity.status = ModelStatus.QUOTA_EXCEEDED

    return entities
def _get_custom_provider_models(
    self,
    model_types: Sequence[ModelType],
    provider_schema: ProviderEntity,
    model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    model: Optional[str] = None,
) -> list[ModelWithProviderEntity]:
    """
    Build the model list used when the provider runs on custom credentials.

    Predefined schema models come first, followed by the custom models that were added
    to the model list.

    :param model_types: model types
    :param provider_schema: provider schema
    :param model_setting_map: model setting map
    :param model: when given, only the custom model with this name is included
        (predefined models are not filtered by it)
    :return: provider models
    """

    def _resolve_settings(setting_type, setting_model, default_status, ignored_source):
        # Returns (status, load_balancing_enabled, has_invalid_load_balancing_configs)
        # derived from the stored model setting, if there is one.
        status = default_status
        lb_enabled = False
        lb_invalid = False
        if setting_type in model_setting_map and setting_model in model_setting_map[setting_type]:
            setting = model_setting_map[setting_type][setting_model]
            if setting.enabled is False:
                status = ModelStatus.DISABLED

            usable_lb_configs = [
                lb_config
                for lb_config in setting.load_balancing_configs
                if lb_config.credential_source_type != ignored_source
            ]
            lb_enabled = setting.load_balancing_enabled
            # when the user enable load_balancing but available configs are less than 2 display warning
            lb_invalid = lb_enabled and len(usable_lb_configs) < 2
        return status, lb_enabled, lb_invalid

    provider_credentials = None
    if self.custom_configuration.provider:
        provider_credentials = self.custom_configuration.provider.credentials

    entities = []

    # predefined models
    for wanted_type in model_types:
        if wanted_type not in self.provider.supported_model_types:
            continue

        for schema_model in provider_schema.models:
            if schema_model.model_type != wanted_type:
                continue

            status, lb_enabled, lb_invalid = _resolve_settings(
                schema_model.model_type,
                schema_model.model,
                ModelStatus.ACTIVE if provider_credentials else ModelStatus.NO_CONFIGURE,
                "custom_model",
            )
            entities.append(
                ModelWithProviderEntity(
                    model=schema_model.model,
                    label=schema_model.label,
                    model_type=schema_model.model_type,
                    features=schema_model.features,
                    fetch_from=schema_model.fetch_from,
                    model_properties=schema_model.model_properties,
                    deprecated=schema_model.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=lb_enabled,
                    has_invalid_load_balancing_configs=lb_invalid,
                )
            )

    # custom models
    for custom_model in self.custom_configuration.models:
        if custom_model.model_type not in model_types:
            continue
        if custom_model.unadded_to_model_list:
            continue
        if model and model != custom_model.model:
            continue

        try:
            custom_schema = self.get_model_schema(
                model_type=custom_model.model_type,
                model=custom_model.model,
                credentials=custom_model.credentials,
            )
        except Exception as ex:
            logger.warning("get custom model schema failed, %s", ex)
            continue

        if not custom_schema:
            continue

        status, lb_enabled, lb_invalid = _resolve_settings(
            custom_schema.model_type, custom_schema.model, ModelStatus.ACTIVE, "provider"
        )

        # credentials are stored for the model but none is currently selected
        if len(custom_model.available_model_credentials) > 0 and not custom_model.credentials:
            status = ModelStatus.CREDENTIAL_REMOVED

        entities.append(
            ModelWithProviderEntity(
                model=custom_schema.model,
                label=custom_schema.label,
                model_type=custom_schema.model_type,
                features=custom_schema.features,
                fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                model_properties=custom_schema.model_properties,
                deprecated=custom_schema.deprecated,
                provider=SimpleModelProviderEntity(self.provider),
                status=status,
                load_balancing_enabled=lb_enabled,
                has_invalid_load_balancing_configs=lb_invalid,
            )
        )

    return entities
class ProviderConfigurations(BaseModel):
    """
    Dict-like container of the provider configurations of one tenant.

    Lookups through `[]` and `get` accept short provider names: a key without "/" is
    expanded through `ModelProviderID` before the lookup.
    """

    # tenant the configurations belong to
    tenant_id: str
    # provider key -> provider configuration
    configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get available models.

        Mode resolution when the preferred provider type is `system`:
            Use the current **system mode** if the provider supports it;
            if no system mode is available (no quota), fall back to the **custom credential mode**;
            if nothing is configured in custom mode, the result is no_configure.
            system > custom > no_configure

        Mode resolution when the preferred provider type is `custom`:
            Use custom mode if custom credentials are configured;
            otherwise use the current **system mode** if supported;
            if no system mode is available (no quota), the result is no_configure.
            custom > system > no_configure

        If the real mode is `system`, system credentials are used to get models,
            paid quotas > provider free quotas > system free quotas
            including pre-defined models (exclude GPT-4, status marked as `no_permission`).
        If the real mode is `custom`, workspace custom credentials are used to get models,
            including pre-defined models and custom models (manually appended).
        If the real mode is `no_configure`, only pre-defined models from `model runtime` are returned
            (status `no_configure` if the preferred provider type is `custom`, otherwise `quota_exceeded`).
        Models whose status is `active` are available.

        :param provider: provider name
        :param model_type: model type
        :param only_active: only active models
        :return: models of every matching provider configuration, in configuration order
        """
        selected = (
            configuration
            for configuration in self.values()
            if not provider or configuration.provider.provider == provider
        )
        return [
            entity
            for configuration in selected
            for entity in configuration.get_provider_models(model_type, only_active)
        ]

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Return all provider configurations as a list.

        :return:
        """
        return [*self.values()]

    def __getitem__(self, key):
        # short provider names are expanded to their full provider id first
        lookup_key = key if "/" in key else str(ModelProviderID(key))
        return self.configurations[lookup_key]

    def __setitem__(self, key, value):
        # keys are stored exactly as given
        self.configurations[key] = value

    def __iter__(self):
        # iterates over the provider keys
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        stored = self.configurations.values()
        return iter(stored)

    def get(self, key, default=None) -> ProviderConfiguration | None:
        # short provider names are expanded to their full provider id first
        lookup_key = key if "/" in key else str(ModelProviderID(key))
        return self.configurations.get(lookup_key, default)  # type: ignore
class ProviderModelBundle(BaseModel):
    """
    Bundle pairing a provider configuration with a model type instance.
    """

    # pydantic configs: allow arbitrary (non-pydantic) field types and
    # leave the protected namespaces empty
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    # provider configuration
    configuration: ProviderConfiguration
    # model type instance
    model_type_instance: AIModel