import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from constants import HIDDEN_VALUE
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
    CustomConfiguration,
    ModelSettings,
    SystemConfiguration,
    SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
    LoadBalancingModelConfig,
    Provider,
    ProviderCredential,
    ProviderModel,
    ProviderModelCredential,
    ProviderModelSetting,
    ProviderType,
    TenantPreferredModelProvider,
)

logger = logging.getLogger(__name__)

# Snapshot of each provider's configurate methods as first seen by this process,
# keyed by provider name (taken before __init__ may append PREDEFINED_MODEL).
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}


class ProviderConfiguration(BaseModel):
    """
    Provider configuration entity for managing model provider settings.

    This class handles:
    - Provider credentials CRUD and switch
    - Custom Model credentials CRUD and switch
    - System vs custom provider switching
    - Load balancing configurations
    - Model enablement/disablement

    TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
    """

    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        super().__init__(**data)

        # Remember the provider's original configurate methods the first time we see it,
        # so the check below is not affected by the PREDEFINED_MODEL we may append.
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = []
            for configurate_method in self.provider.configurate_methods:
                original_provider_configurate_methods[self.provider.provider].append(configurate_method)

        # A customizable-only provider that has system-quota restricted models also
        # exposes those models as predefined ones.
        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            if (
                any(
                    len(quota_configuration.restrict_models) > 0
                    for quota_configuration in self.system_configuration.quota_configurations
                )
                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
            ):
                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return: a copy of the system credentials (with ``base_model_name`` set for restricted
            models) when using the system provider; otherwise the custom model credentials,
            falling back to the custom provider credentials (may be None)
        :raises ValueError: if the model has been disabled in the model settings
        """
        if self.model_settings:
            # check if model is disabled by admin
            for model_setting in self.model_settings:
                if model_setting.model_type == model_type and model_setting.model == model:
                    if not model_setting.enabled:
                        raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models

            copy_credentials = (
                self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
            )
            if restrict_models:
                for restrict_model in restrict_models:
                    if (
                        restrict_model.model_type == model_type
                        and restrict_model.model == model
                        and restrict_model.base_model_name
                    ):
                        copy_credentials["base_model_name"] = restrict_model.base_model_name

            return copy_credentials
        else:
            credentials = None
            if self.custom_configuration.models:
                for model_configuration in self.custom_configuration.models:
                    if model_configuration.model_type == model_type and model_configuration.model == model:
                        credentials = model_configuration.credentials
                        break

            if not credentials and self.custom_configuration.provider:
                credentials = self.custom_configuration.provider.credentials

            return credentials

    def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled, None when no quota
            configuration matches the current quota type, otherwise ACTIVE / QUOTA_EXCEEDED
            depending on the matching quota configuration's ``is_valid`` flag
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
        )
        if current_quota_configuration is None:
            return None

        if not current_quota_configuration:
            return SystemConfigurationStatus.UNSUPPORTED

        return (
            SystemConfigurationStatus.ACTIVE
            if current_quota_configuration.is_valid
            else SystemConfigurationStatus.QUOTA_EXCEEDED
        )

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.

        :return: True if the custom provider has at least one available credential or at
            least one custom model configuration exists
        """
        has_provider_credentials = (
            self.custom_configuration.provider is not None
            and len(self.custom_configuration.provider.available_credentials) > 0
        )

        has_model_configurations = len(self.custom_configuration.models) > 0
        return has_provider_credentials or has_model_configurations

    def _get_provider_record(self, session: Session) -> Provider | None:
        """
        Get the custom provider record for this tenant.

        For langgenius plugin providers, the short (legacy) provider name is also matched.

        :param session: database session
        :return: the custom Provider record, or None
        """
        # get provider
        model_provider_id = ModelProviderID(self.provider.provider)
        provider_names = [self.provider.provider]
        if model_provider_id.is_langgenius():
            provider_names.append(model_provider_id.provider_name)

        stmt = select(Provider).where(
            Provider.tenant_id == self.tenant_id,
            Provider.provider_type == ProviderType.CUSTOM.value,
            Provider.provider_name.in_(provider_names),
        )

        return session.execute(stmt).scalar_one_or_none()

    def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
        """
        Get a specific provider credential by ID.

        :param credential_id: Credential ID
        :return: the credential's values with secrets decrypted and then obfuscated
        :raises ValueError: if the credential does not exist or has no stored config
        """
        # Extract secret variables from provider credential schema
        credential_secret_variables = self.extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else []
        )

        with Session(db.engine) as session:
            # Prefer the actual provider record name if exists (to handle aliased provider names)
            provider_record = self._get_provider_record(session)
            provider_name = provider_record.provider_name if provider_record else self.provider.provider

            stmt = select(ProviderCredential).where(
                ProviderCredential.id == credential_id,
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == provider_name,
            )

            credential = session.execute(stmt).scalar_one_or_none()

            if not credential or not credential.encrypted_config:
                raise ValueError(f"Credential with id {credential_id} not found.")

            try:
                credentials = json.loads(credential.encrypted_config)
            except JSONDecodeError:
                credentials = {}

            # Decrypt secret variables
            for key in credential_secret_variables:
                if key in credentials and credentials[key] is not None:
                    try:
                        credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
                    except Exception:
                        # best-effort: keep the stored value if it cannot be decrypted
                        pass

            return self.obfuscated_credentials(
                credentials=credentials,
                credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
                if self.provider.provider_credential_schema
                else [],
            )

    def _check_provider_credential_name_exists(
        self, credential_name: str, session: Session, exclude_id: str | None = None
    ) -> bool:
        """
        Check whether a provider credential with this name already exists.

        Duplicate names are not allowed when creating or updating a credential.

        :param credential_name: name to check
        :param session: database session
        :param exclude_id: credential id to ignore (the one being updated)
        :return: True if another credential already uses the name
        """
        stmt = select(ProviderCredential.id).where(
            ProviderCredential.tenant_id == self.tenant_id,
            ProviderCredential.provider_name == self.provider.provider,
            ProviderCredential.credential_name == credential_name,
        )
        if exclude_id:
            stmt = stmt.where(ProviderCredential.id != exclude_id)
        # Existence check: .first() does not raise if legacy duplicates already exist,
        # unlike scalar_one_or_none().
        return session.execute(stmt.limit(1)).first() is not None

    def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
        """
        Get provider credentials.

        :param credential_id: if provided, return the specified credential
        :return: obfuscated credentials (the currently active ones when no id is given)
        """
        if credential_id:
            return self._get_specific_provider_credential(credential_id)

        # Default behavior: return current active provider credentials
        credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}

        return self.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else [],
        )

    def validate_provider_credentials(
        self, credentials: dict, credential_id: str = "", session: Session | None = None
    ) -> dict:
        """
        Validate custom credentials.

        Note: secret fields sent as HIDDEN_VALUE are replaced in ``credentials`` (in place)
        by the decrypted stored value of the credential identified by ``credential_id``.

        :param credentials: provider credentials
        :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
        :param session: optional database session
        :return: validated credentials with secret fields encrypted
        """

        def _validate(s: Session) -> dict:
            # Get provider credential secret variables
            provider_credential_secret_variables = self.extract_secret_variables(
                self.provider.provider_credential_schema.credential_form_schemas
                if self.provider.provider_credential_schema
                else []
            )

            if credential_id:
                try:
                    stmt = select(ProviderCredential).where(
                        ProviderCredential.tenant_id == self.tenant_id,
                        ProviderCredential.provider_name == self.provider.provider,
                        ProviderCredential.id == credential_id,
                    )
                    credential_record = s.execute(stmt).scalar_one_or_none()
                    # fix origin data: legacy records may store a bare (non-JSON) openai api key
                    if credential_record and credential_record.encrypted_config:
                        if not credential_record.encrypted_config.startswith("{"):
                            original_credentials = {"openai_api_key": credential_record.encrypted_config}
                        else:
                            original_credentials = json.loads(credential_record.encrypted_config)
                    else:
                        original_credentials = {}
                except JSONDecodeError:
                    original_credentials = {}

                # decrypt the stored values of secret fields that the client sent back hidden
                for key, value in credentials.items():
                    if key in provider_credential_secret_variables:
                        # if send [__HIDDEN__] in secret input, it will be same as original value
                        if value == HIDDEN_VALUE and key in original_credentials:
                            credentials[key] = encrypter.decrypt_token(
                                tenant_id=self.tenant_id, token=original_credentials[key]
                            )

            model_provider_factory = ModelProviderFactory(self.tenant_id)
            validated_credentials = model_provider_factory.provider_credentials_validate(
                provider=self.provider.provider, credentials=credentials
            )

            # encrypt secret fields before they are persisted
            for key, value in validated_credentials.items():
                if key in provider_credential_secret_variables:
                    validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

            return validated_credentials

        if session:
            return _validate(session)
        else:
            with Session(db.engine) as new_session:
                return _validate(new_session)

    def _generate_provider_credential_name(self, session) -> str:
        """
        Generate a unique credential name for provider.

        :param session: database session
        :return: credential name
        """
        return self._generate_next_api_key_name(
            session=session,
            query_factory=lambda: select(ProviderCredential).where(
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == self.provider.provider,
            ),
        )

    def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
        """
        Generate a unique credential name for custom model.

        :param model: model name
        :param model_type: model type
        :param session: database session
        :return: credential name
        """
        return self._generate_next_api_key_name(
            session=session,
            query_factory=lambda: select(ProviderModelCredential).where(
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            ),
        )

    def _generate_next_api_key_name(self, session, query_factory) -> str:
        """
        Generate next available API KEY name by finding the highest numbered suffix.

        :param session: database session
        :param query_factory: function that returns the SQLAlchemy query
        :return: next available API KEY name ("API KEY 1" if none exist or on error)
        """
        try:
            stmt = query_factory()
            credential_records = session.execute(stmt).scalars().all()

            if not credential_records:
                return "API KEY 1"

            # Extract numbers from API KEY pattern using list comprehension
            pattern = re.compile(r"^API KEY\s+(\d+)$")
            numbers = [
                int(match.group(1))
                for cr in credential_records
                if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
            ]

            # Return next sequential number
            next_number = max(numbers, default=0) + 1
            return f"API KEY {next_number}"

        except Exception as e:
            logger.warning("Error generating next credential name: %s", str(e))
            return "API KEY 1"

    def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
        """
        Add custom provider credentials.

        Creates the custom Provider record if it does not exist yet and switches the
        preferred provider type to CUSTOM.

        :param credentials: provider credentials
        :param credential_name: credential name (auto-generated when empty)
        :return:
        :raises ValueError: if a credential with the same name already exists
        """
        with Session(db.engine) as session:
            if credential_name:
                if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
                    raise ValueError(f"Credential with name '{credential_name}' already exists.")
            else:
                credential_name = self._generate_provider_credential_name(session)

            credentials = self.validate_provider_credentials(credentials=credentials, session=session)
            provider_record = self._get_provider_record(session)
            try:
                new_record = ProviderCredential(
                    tenant_id=self.tenant_id,
                    provider_name=self.provider.provider,
                    encrypted_config=json.dumps(credentials),
                    credential_name=credential_name,
                )
                session.add(new_record)
                # flush so new_record.id is available for the provider record below
                session.flush()

                if not provider_record:
                    # If provider record does not exist, create it
                    provider_record = Provider(
                        tenant_id=self.tenant_id,
                        provider_name=self.provider.provider,
                        provider_type=ProviderType.CUSTOM.value,
                        is_valid=True,
                        credential_id=new_record.id,
                    )
                    session.add(provider_record)

                provider_model_credentials_cache = ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=provider_record.id,
                    cache_type=ProviderCredentialsCacheType.PROVIDER,
                )
                provider_model_credentials_cache.delete()

                self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)

                session.commit()
            except Exception:
                session.rollback()
                raise

    def update_provider_credential(
        self,
        credentials: dict,
        credential_id: str,
        credential_name: str | None,
    ) -> None:
        """
        update a saved provider credential (by credential_id).

        Load balancing configs that reference this credential are updated as well.

        :param credentials: provider credentials
        :param credential_id: credential id
        :param credential_name: credential name (unchanged when empty)
        :return:
        :raises ValueError: if the name is taken or the credential does not exist
        """
        with Session(db.engine) as session:
            if credential_name and self._check_provider_credential_name_exists(
                credential_name=credential_name, session=session, exclude_id=credential_id
            ):
                raise ValueError(f"Credential with name '{credential_name}' already exists.")

            credentials = self.validate_provider_credentials(
                credentials=credentials, credential_id=credential_id, session=session
            )
            provider_record = self._get_provider_record(session)
            stmt = select(ProviderCredential).where(
                ProviderCredential.id == credential_id,
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == self.provider.provider,
            )

            # Get the credential record to update
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")
            try:
                # Update credential
                credential_record.encrypted_config = json.dumps(credentials)
                credential_record.updated_at = naive_utc_now()
                if credential_name:
                    credential_record.credential_name = credential_name
                session.commit()

                # only the active credential is cached at provider level
                if provider_record and provider_record.credential_id == credential_id:
                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_record.id,
                        cache_type=ProviderCredentialsCacheType.PROVIDER,
                    )
                    provider_model_credentials_cache.delete()

                self._update_load_balancing_configs_with_credential(
                    credential_id=credential_id,
                    credential_record=credential_record,
                    credential_source="provider",
                    session=session,
                )
            except Exception:
                session.rollback()
                raise

    def _update_load_balancing_configs_with_credential(
        self,
        credential_id: str,
        credential_record: ProviderCredential | ProviderModelCredential,
        credential_source: str,
        session: Session,
    ) -> None:
        """
        Update load balancing configurations that reference the given credential_id.

        :param credential_id: credential id
        :param credential_record: the encrypted_config and credential_name
        :param credential_source: the credential comes from the provider_credential(`provider`)
                                  or the provider_model_credential(`custom_model`)
        :param session: the database session
        :return:
        """
        # Find all load balancing configs that use this credential_id
        stmt = select(LoadBalancingModelConfig).where(
            LoadBalancingModelConfig.tenant_id == self.tenant_id,
            LoadBalancingModelConfig.provider_name == self.provider.provider,
            LoadBalancingModelConfig.credential_id == credential_id,
            LoadBalancingModelConfig.credential_source_type == credential_source,
        )
        load_balancing_configs = session.execute(stmt).scalars().all()

        if not load_balancing_configs:
            return

        # Update each load balancing config with the new credentials
        for lb_config in load_balancing_configs:
            # Update the encrypted_config with the new credentials
            lb_config.encrypted_config = credential_record.encrypted_config
            lb_config.name = credential_record.credential_name
            lb_config.updated_at = naive_utc_now()

            # Clear cache for this load balancing config
            lb_credentials_cache = ProviderCredentialsCache(
                tenant_id=self.tenant_id,
                identity_id=lb_config.id,
                cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
            )
            lb_credentials_cache.delete()

        session.commit()

    def delete_provider_credential(self, credential_id: str) -> None:
        """
        Delete a saved provider credential (by credential_id).

        Load balancing configs using the credential are deleted. If it was the last
        credential, the provider record is deleted too. If it was the active one, the
        provider's active credential is cleared. In both of those cases the preferred
        provider type is switched back to SYSTEM.

        :param credential_id: credential id
        :return:
        :raises ValueError: if the credential does not exist
        """
        with Session(db.engine) as session:
            stmt = select(ProviderCredential).where(
                ProviderCredential.id == credential_id,
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == self.provider.provider,
            )

            # Get the credential record to delete
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")

            # Check if this credential is used in load balancing configs
            lb_stmt = select(LoadBalancingModelConfig).where(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.credential_id == credential_id,
                LoadBalancingModelConfig.credential_source_type == "provider",
            )
            lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
            try:
                for lb_config in lb_configs_using_credential:
                    lb_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=lb_config.id,
                        cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                    )
                    lb_credentials_cache.delete()
                    session.delete(lb_config)

                # Check if this is the currently active credential
                provider_record = self._get_provider_record(session)

                # Check available credentials count BEFORE deleting
                # if this is the last credential, we need to delete the provider record
                count_stmt = select(func.count(ProviderCredential.id)).where(
                    ProviderCredential.tenant_id == self.tenant_id,
                    ProviderCredential.provider_name == self.provider.provider,
                )
                available_credentials_count = session.execute(count_stmt).scalar() or 0
                session.delete(credential_record)

                if provider_record and available_credentials_count <= 1:
                    # If all credentials are deleted, delete the provider record
                    session.delete(provider_record)

                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_record.id,
                        cache_type=ProviderCredentialsCacheType.PROVIDER,
                    )
                    provider_model_credentials_cache.delete()
                    self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
                elif provider_record and provider_record.credential_id == credential_id:
                    provider_record.credential_id = None
                    provider_record.updated_at = naive_utc_now()

                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_record.id,
                        cache_type=ProviderCredentialsCacheType.PROVIDER,
                    )
                    provider_model_credentials_cache.delete()
                    self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)

                session.commit()
            except Exception:
                session.rollback()
                raise

    def switch_active_provider_credential(self, credential_id: str) -> None:
        """
        Switch active provider credential (copy the selected one into current active snapshot).

        :param credential_id: credential id
        :return:
        :raises ValueError: if the credential or the provider record does not exist
        """
        with Session(db.engine) as session:
            stmt = select(ProviderCredential).where(
                ProviderCredential.id == credential_id,
                ProviderCredential.tenant_id == self.tenant_id,
                ProviderCredential.provider_name == self.provider.provider,
            )
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")

            provider_record = self._get_provider_record(session)
            if not provider_record:
                raise ValueError("Provider record not found.")

            try:
                provider_record.credential_id = credential_record.id
                provider_record.updated_at = naive_utc_now()
                session.commit()

                provider_model_credentials_cache = ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=provider_record.id,
                    cache_type=ProviderCredentialsCacheType.PROVIDER,
                )
                provider_model_credentials_cache.delete()
                self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
            except Exception:
                session.rollback()
                raise

    def _get_custom_model_record(
        self,
        model_type: ModelType,
        model: str,
        session: Session,
    ) -> ProviderModel | None:
        """
        Get the custom model record (ProviderModel) for this tenant.

        For langgenius plugin providers, the short (legacy) provider name is also matched.

        :param model_type: model type
        :param model: model name
        :param session: database session
        :return: the ProviderModel record, or None
        """
        # get provider model
        model_provider_id = ModelProviderID(self.provider.provider)
        provider_names = [self.provider.provider]
        if model_provider_id.is_langgenius():
            provider_names.append(model_provider_id.provider_name)

        stmt = select(ProviderModel).where(
            ProviderModel.tenant_id == self.tenant_id,
            ProviderModel.provider_name.in_(provider_names),
            ProviderModel.model_name == model,
            ProviderModel.model_type == model_type.to_origin_model_type(),
        )

        return session.execute(stmt).scalar_one_or_none()

    def _get_specific_custom_model_credential(
        self, model_type: ModelType, model: str, credential_id: str
    ) -> dict | None:
        """
        Get a specific custom model credential by ID.

        :param model_type: model type
        :param model: model name
        :param credential_id: Credential ID
        :return: dict with ``current_credential_id``, ``current_credential_name`` and
            obfuscated ``credentials``
        :raises ValueError: if the credential does not exist or has no stored config
        """
        model_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema
            else []
        )

        with Session(db.engine) as session:
            stmt = select(ProviderModelCredential).where(
                ProviderModelCredential.id == credential_id,
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            )

            credential_record = session.execute(stmt).scalar_one_or_none()

            if not credential_record or not credential_record.encrypted_config:
                raise ValueError(f"Credential with id {credential_id} not found.")

            try:
                credentials = json.loads(credential_record.encrypted_config)
            except JSONDecodeError:
                credentials = {}

            # Decrypt secret variables
            for key in model_credential_secret_variables:
                if key in credentials and credentials[key] is not None:
                    try:
                        credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
                    except Exception:
                        # best-effort: keep the stored value if it cannot be decrypted
                        pass

            current_credential_id = credential_record.id
            current_credential_name = credential_record.credential_name
            credentials = self.obfuscated_credentials(
                credentials=credentials,
                credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                if self.provider.model_credential_schema
                else [],
            )

            return {
                "current_credential_id": current_credential_id,
                "current_credential_name": current_credential_name,
                "credentials": credentials,
            }

    def _check_custom_model_credential_name_exists(
        self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
    ) -> bool:
        """
        Check whether a credential with this name already exists for the custom model.

        Duplicate names are not allowed when creating or updating a credential.

        :param model_type: model type
        :param model: model name
        :param credential_name: name to check
        :param session: database session
        :param exclude_id: credential id to ignore (the one being updated)
        :return: True if another credential of this model already uses the name
        """
        stmt = select(ProviderModelCredential.id).where(
            ProviderModelCredential.tenant_id == self.tenant_id,
            ProviderModelCredential.provider_name == self.provider.provider,
            ProviderModelCredential.model_name == model,
            ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            ProviderModelCredential.credential_name == credential_name,
        )
        if exclude_id:
            stmt = stmt.where(ProviderModelCredential.id != exclude_id)
        # Existence check: .first() does not raise if legacy duplicates already exist,
        # unlike scalar_one_or_none().
        return session.execute(stmt.limit(1)).first() is not None

    def get_custom_model_credential(
        self, model_type: ModelType, model: str, credential_id: str | None
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credential_id: if provided, return the specified credential
        :return: dict with ``current_credential_id``, ``current_credential_name`` and
            obfuscated ``credentials``; None if the model has no configured credentials
        """
        # If credential_id is provided, return the specific credential
        if credential_id:
            return self._get_specific_custom_model_credential(
                model_type=model_type, model=model, credential_id=credential_id
            )

        for model_configuration in self.custom_configuration.models:
            if (
                model_configuration.model_type == model_type
                and model_configuration.model == model
                and model_configuration.credentials
            ):
                current_credential_id = model_configuration.current_credential_id
                current_credential_name = model_configuration.current_credential_name
                credentials = self.obfuscated_credentials(
                    credentials=model_configuration.credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema
                    else [],
                )
                return {
                    "current_credential_id": current_credential_id,
                    "current_credential_name": current_credential_name,
                    "credentials": credentials,
                }
        return None

    def validate_custom_model_credentials(
        self,
        model_type: ModelType,
        model: str,
        credentials: dict,
        credential_id: str = "",
        session: Session | None = None,
    ) -> dict:
        """
        Validate custom model credentials.

        Note: secret fields sent as HIDDEN_VALUE are replaced in ``credentials`` (in place)
        by the decrypted stored value of the credential identified by ``credential_id``.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials dict
        :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
        :param session: optional database session
        :return: validated credentials with secret fields encrypted
        """

        def _validate(s: Session) -> dict:
            # Get provider credential secret variables
            provider_credential_secret_variables = self.extract_secret_variables(
                self.provider.model_credential_schema.credential_form_schemas
                if self.provider.model_credential_schema
                else []
            )

            if credential_id:
                try:
                    stmt = select(ProviderModelCredential).where(
                        ProviderModelCredential.id == credential_id,
                        ProviderModelCredential.tenant_id == self.tenant_id,
                        ProviderModelCredential.provider_name == self.provider.provider,
                        ProviderModelCredential.model_name == model,
                        ProviderModelCredential.model_type == model_type.to_origin_model_type(),
                    )
                    credential_record = s.execute(stmt).scalar_one_or_none()
                    original_credentials = (
                        json.loads(credential_record.encrypted_config)
                        if credential_record and credential_record.encrypted_config
                        else {}
                    )
                except JSONDecodeError:
                    original_credentials = {}

                # decrypt the stored values of secret fields that the client sent back hidden
                for key, value in credentials.items():
                    if key in provider_credential_secret_variables:
                        # if send [__HIDDEN__] in secret input, it will be same as original value
                        if value == HIDDEN_VALUE and key in original_credentials:
                            credentials[key] = encrypter.decrypt_token(
                                tenant_id=self.tenant_id, token=original_credentials[key]
                            )

            model_provider_factory = ModelProviderFactory(self.tenant_id)
            validated_credentials = model_provider_factory.model_credentials_validate(
                provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
            )

            # encrypt secret fields before they are persisted
            for key, value in validated_credentials.items():
                if key in provider_credential_secret_variables:
                    validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

            return validated_credentials

        if session:
            return _validate(session)
        else:
            with Session(db.engine) as new_session:
                return _validate(new_session)

    def create_custom_model_credential(
        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
    ) -> None:
        """
        Create a custom model credential.

        Creates the custom model record (ProviderModel) if it does not exist yet.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials dict
        :param credential_name: credential name (auto-generated when empty)
        :return:
        :raises ValueError: if a credential with the same name already exists for the model
        """
        with Session(db.engine) as session:
            if credential_name:
                if self._check_custom_model_credential_name_exists(
                    model=model, model_type=model_type, credential_name=credential_name, session=session
                ):
                    raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
            else:
                credential_name = self._generate_custom_model_credential_name(
                    model=model, model_type=model_type, session=session
                )
            # validate custom model config
            credentials = self.validate_custom_model_credentials(
                model_type=model_type, model=model, credentials=credentials, session=session
            )
            provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

            try:
                credential = ProviderModelCredential(
                    tenant_id=self.tenant_id,
                    provider_name=self.provider.provider,
                    model_name=model,
                    model_type=model_type.to_origin_model_type(),
                    encrypted_config=json.dumps(credentials),
                    credential_name=credential_name,
                )
                session.add(credential)
                # flush so credential.id is available for the provider model record below
                session.flush()

                # save provider model
                if not provider_model_record:
                    provider_model_record = ProviderModel(
                        tenant_id=self.tenant_id,
                        provider_name=self.provider.provider,
                        model_name=model,
                        model_type=model_type.to_origin_model_type(),
                        credential_id=credential.id,
                        is_valid=True,
                    )
                    session.add(provider_model_record)

                session.commit()

                provider_model_credentials_cache = ProviderCredentialsCache(
                    tenant_id=self.tenant_id,
                    identity_id=provider_model_record.id,
                    cache_type=ProviderCredentialsCacheType.MODEL,
                )
                provider_model_credentials_cache.delete()
            except Exception:
                session.rollback()
                raise

    def update_custom_model_credential(
        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
    ) -> None:
        """
        Update a custom model credential.

        Load balancing configs that reference this credential are updated as well.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials dict
        :param credential_name: credential name (unchanged when empty)
        :param credential_id: credential id
        :return:
        :raises ValueError: if the name is taken or the credential does not exist
        """
        with Session(db.engine) as session:
            if credential_name and self._check_custom_model_credential_name_exists(
                model=model,
                model_type=model_type,
                credential_name=credential_name,
                session=session,
                exclude_id=credential_id,
            ):
                raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
            # validate custom model config
            credentials = self.validate_custom_model_credentials(
                model_type=model_type,
                model=model,
                credentials=credentials,
                credential_id=credential_id,
                session=session,
            )
            provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

            stmt = select(ProviderModelCredential).where(
                ProviderModelCredential.id == credential_id,
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            )
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")

            try:
                # Update credential
                credential_record.encrypted_config = json.dumps(credentials)
                credential_record.updated_at = naive_utc_now()
                if credential_name:
                    credential_record.credential_name = credential_name
                session.commit()

                # only the active credential is cached at model level
                if provider_model_record and provider_model_record.credential_id == credential_id:
                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_model_record.id,
                        cache_type=ProviderCredentialsCacheType.MODEL,
                    )
                    provider_model_credentials_cache.delete()

                self._update_load_balancing_configs_with_credential(
                    credential_id=credential_id,
                    credential_record=credential_record,
                    credential_source="custom_model",
                    session=session,
                )
            except Exception:
                session.rollback()
                raise

    def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
        """
        Delete a saved custom model credential (by credential_id).

        Load balancing configs using the credential are deleted. If it was the model's
        last credential, the custom model record is deleted too. If it was the active
        one, the model's active credential is cleared.

        :param model_type: model type
        :param model: model name
        :param credential_id: credential id
        :return:
        :raises ValueError: if the credential does not exist
        """
        with Session(db.engine) as session:
            stmt = select(ProviderModelCredential).where(
                ProviderModelCredential.id == credential_id,
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            )
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")

            lb_stmt = select(LoadBalancingModelConfig).where(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.credential_id == credential_id,
                LoadBalancingModelConfig.credential_source_type == "custom_model",
            )
            lb_configs_using_credential = session.execute(lb_stmt).scalars().all()

            try:
                for lb_config in lb_configs_using_credential:
                    lb_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=lb_config.id,
                        cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                    )
                    lb_credentials_cache.delete()
                    session.delete(lb_config)

                # Check if this is the currently active credential
                provider_model_record = self._get_custom_model_record(model_type, model, session=session)

                # Check available credentials count BEFORE deleting
                # if this is the last credential, we need to delete the custom model record
                count_stmt = select(func.count(ProviderModelCredential.id)).where(
                    ProviderModelCredential.tenant_id == self.tenant_id,
                    ProviderModelCredential.provider_name == self.provider.provider,
                    ProviderModelCredential.model_name == model,
                    ProviderModelCredential.model_type == model_type.to_origin_model_type(),
                )
                available_credentials_count = session.execute(count_stmt).scalar() or 0
                session.delete(credential_record)

                if provider_model_record and available_credentials_count <= 1:
                    # If all credentials are deleted, delete the custom model record
                    session.delete(provider_model_record)

                    # The model-level cache (keyed by the ProviderModel id, as in create/update)
                    # would otherwise keep serving the deleted credential.
                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_model_record.id,
                        cache_type=ProviderCredentialsCacheType.MODEL,
                    )
                    provider_model_credentials_cache.delete()
                elif provider_model_record and provider_model_record.credential_id == credential_id:
                    provider_model_record.credential_id = None
                    provider_model_record.updated_at = naive_utc_now()

                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=self.tenant_id,
                        identity_id=provider_model_record.id,
                        cache_type=ProviderCredentialsCacheType.MODEL,
                    )
                    provider_model_credentials_cache.delete()

                session.commit()
            except Exception:
                session.rollback()
                raise

    def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
        """
        if model list exist this custom model, switch the custom model credential.
        if model list not exist this custom model, use the credential to add a new custom model record.

        :param model_type: model type
        :param model: model name
        :param credential_id: credential id
        :return:
        :raises ValueError: if the credential does not exist
        """
        with Session(db.engine) as session:
            stmt = select(ProviderModelCredential).where(
                ProviderModelCredential.id == credential_id,
                ProviderModelCredential.tenant_id == self.tenant_id,
                ProviderModelCredential.provider_name == self.provider.provider,
                ProviderModelCredential.model_name == model,
                ProviderModelCredential.model_type == model_type.to_origin_model_type(),
            )
            credential_record = session.execute(stmt).scalar_one_or_none()
            if not credential_record:
                raise ValueError("Credential record not found.")

            # look up the existing custom model record, if any
            provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

            if not provider_model_record:
                # create provider model record
                provider_model_record = ProviderModel(
                    tenant_id=self.tenant_id,
                    provider_name=self.provider.provider,
                    model_name=model,
                    model_type=model_type.to_origin_model_type(),
                    is_valid=True,
                    credential_id=credential_id,
                )
else: if provider_model_record.credential_id == credential_record.id: raise ValueError("Can't add same credential") provider_model_record.credential_id = credential_record.id provider_model_record.updated_at = naive_utc_now() session.add(provider_model_record) session.commit() def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: """ switch the custom model credential. :param model_type: model type :param model: model name :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) if not provider_model_record: raise ValueError("The custom model record not found.") provider_model_record.credential_id = credential_record.id provider_model_record.updated_at = naive_utc_now() session.add(provider_model_record) session.commit() def delete_custom_model(self, model_type: ModelType, model: str) -> None: """ Delete custom model. 
:param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: # get provider model provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) # delete provider model if provider_model_record: session.delete(provider_model_record) session.commit() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() def _get_provider_model_setting( self, model_type: ModelType, model: str, session: Session ) -> ProviderModelSetting | None: """ Get provider model setting. """ model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) stmt = select(ProviderModelSetting).where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) return session.execute(stmt).scalars().first() def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Enable model. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.enabled = True model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, enabled=True, ) session.add(model_setting) session.commit() return model_setting def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Disable model. 
:param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.enabled = False model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, enabled=False, ) session.add(model_setting) session.commit() return model_setting def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: """ Get provider model setting. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: return self._get_provider_model_setting(model_type=model_type, model=model, session=session) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Enable model load balancing. 
:param model_type: model type :param model: model name :return: """ model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) with Session(db.engine) as session: stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) load_balancing_config_count = session.execute(stmt).scalar() or 0 if load_balancing_config_count <= 1: raise ValueError("Model load balancing configuration must be more than 1.") model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.load_balancing_enabled = True model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, load_balancing_enabled=True, ) session.add(model_setting) session.commit() return model_setting def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Disable model load balancing. 
:param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.load_balancing_enabled = False model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, load_balancing_enabled=False, ) session.add(model_setting) session.commit() return model_setting def get_model_type_instance(self, model_type: ModelType) -> AIModel: """ Get current model type instance. :param model_type: model type :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: """ Get model schema """ model_provider_factory = ModelProviderFactory(self.tenant_id) return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: """ Switch preferred provider type. 
:param provider_type: :return: """ if provider_type == self.preferred_provider_type: return if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return def _switch(s: Session) -> None: # get preferred provider model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) preferred_model_provider = s.execute(stmt).scalars().first() if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value else: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, preferred_provider_type=provider_type.value, ) s.add(preferred_model_provider) s.commit() if session: return _switch(session) else: with Session(db.engine) as session: return _switch(session) def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. :param credential_form_schemas: :return: """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: if credential_form_schema.type == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: """ Obfuscated credentials. 
:param credentials: credentials :param credential_form_schemas: credential form schemas :return: """ # Get provider credential secret variables credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() for key, value in copy_credentials.items(): if key in credential_secret_variables: copy_credentials[key] = encrypter.obfuscated_token(value) return copy_credentials def get_provider_model( self, model_type: ModelType, model: str, only_active: bool = False ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type :param model: model name :param only_active: return active model only :return: """ provider_models = self.get_provider_models(model_type, only_active, model) for provider_model in provider_models: if provider_model.model == model: return provider_model return None def get_provider_models( self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None ) -> list[ModelWithProviderEntity]: """ Get provider models. 
:param model_type: model type :param only_active: only active models :param model: model name :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] if model_type: model_types.append(model_type) else: model_types = list(provider_schema.supported_model_types) # Group model settings by model type and model model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map, model=model, ) if only_active: provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] # resort provider_models # Optimize sorting logic: first sort by provider.position order, then by model_type.value # Get the position list for model types (retrieve only once for better performance) model_type_positions = {} if hasattr(self.provider, "position") and self.provider.position: model_type_positions = self.provider.position def get_sort_key(model: ModelWithProviderEntity): # Get the position list for the current model type positions = model_type_positions.get(model.model_type.value, []) # If the model name is in the position list, use its index for sorting # Otherwise use a large value (list length) to place undefined models at the end position_index = positions.index(model.model) if model.model in positions else len(positions) # Return composite sort key: (model_type value, model position index) return (model.model_type.value, position_index) # Sort 
using the composite sort key return sorted(provider_models, key=get_sort_key) def _get_system_provider_models( self, model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], ) -> list[ModelWithProviderEntity]: """ Get system provider models. :param model_types: model types :param provider_schema: provider schema :param model_setting_map: model setting map :return: """ provider_models = [] for model_type in model_types: for m in provider_schema.models: if m.model_type != model_type: continue status = ModelStatus.ACTIVE if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, model_type=m.model_type, features=m.features, fetch_from=m.fetch_from, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, ) ) if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in provider_schema.configurate_methods: original_provider_configurate_methods[self.provider.provider].append(configurate_method) should_use_custom_model = False if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: should_use_custom_model = True for quota_configuration in self.system_configuration.quota_configurations: if self.system_configuration.current_quota_type != quota_configuration.quota_type: continue restrict_models = quota_configuration.restrict_models if len(restrict_models) == 0: break if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ ConfigurateMethod.CUSTOMIZABLE_MODEL ]: # only customizable model for restrict_model in 
restrict_models: copy_credentials = ( self.system_configuration.credentials.copy() if self.system_configuration.credentials else {} ) if restrict_model.base_model_name: copy_credentials["base_model_name"] = restrict_model.base_model_name try: custom_model_schema = self.get_model_schema( model_type=restrict_model.model_type, model=restrict_model.model, credentials=copy_credentials, ) except Exception as ex: logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: continue if custom_model_schema.model_type not in model_types: continue status = ModelStatus.ACTIVE if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, label=custom_model_schema.label, model_type=custom_model_schema.model_type, features=custom_model_schema.features, fetch_from=FetchFrom.PREDEFINED_MODEL, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, ) ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] for model in provider_models: if model.model_type == ModelType.LLM and model.model not in restrict_model_names: model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: model.status = ModelStatus.QUOTA_EXCEEDED return provider_models def _get_custom_provider_models( self, model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], model: Optional[str] = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. 
:param model_types: model types :param provider_schema: provider schema :param model_setting_map: model setting map :return: """ provider_models = [] credentials = None if self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials for model_type in model_types: if model_type not in self.provider.supported_model_types: continue for m in provider_schema.models: if m.model_type != model_type: continue status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False has_invalid_load_balancing_configs = False if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs if config.credential_source_type != "custom_model" ] load_balancing_enabled = model_setting.load_balancing_enabled # when the user enable load_balancing but available configs are less than 2 display warning has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, model_type=m.model_type, features=m.features, fetch_from=m.fetch_from, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) # custom models for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue if model_configuration.unadded_to_model_list: continue if model and model != model_configuration.model: continue try: custom_model_schema = self.get_model_schema( model_type=model_configuration.model_type, model=model_configuration.model, 
credentials=model_configuration.credentials, ) except Exception as ex: logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: continue status = ModelStatus.ACTIVE load_balancing_enabled = False has_invalid_load_balancing_configs = False if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs if config.credential_source_type != "provider" ] load_balancing_enabled = model_setting.load_balancing_enabled # when the user enable load_balancing but available configs are less than 2 display warning has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: status = ModelStatus.CREDENTIAL_REMOVED provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, label=custom_model_schema.label, model_type=custom_model_schema.model_type, features=custom_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) return provider_models class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. 
""" tenant_id: str configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict) def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) def get_models( self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False ) -> list[ModelWithProviderEntity]: """ Get available models. If preferred provider type is `system`: Get the current **system mode** if provider supported, if all system modes are not available (no quota), it is considered to be the **custom credential mode**. If there is no model configured in custom mode, it is treated as no_configure. system > custom > no_configure If preferred provider type is `custom`: If custom credentials are configured, it is treated as custom mode. Otherwise, get the current **system mode** if supported, If all system modes are not available (no quota), it is treated as no_configure. custom > system > no_configure If real mode is `system`, use system credentials to get models, paid quotas > provider free quotas > system free quotas include pre-defined models (exclude GPT-4, status marked as `no_permission`). If real mode is `custom`, use workspace custom credentials to get models, include pre-defined models, custom models(manual append). If real mode is `no_configure`, only return pre-defined models from `model runtime`. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`) model status marked as `active` is available. :param provider: provider name :param model_type: model type :param only_active: only active models :return: """ all_models = [] for provider_configuration in self.values(): if provider and provider_configuration.provider.provider != provider: continue all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) return all_models def to_list(self) -> list[ProviderConfiguration]: """ Convert to list. 
:return: """ return list(self.values()) def __getitem__(self, key): if "/" not in key: key = str(ModelProviderID(key)) return self.configurations[key] def __setitem__(self, key, value): self.configurations[key] = value def __iter__(self): return iter(self.configurations) def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) def get(self, key, default=None) -> ProviderConfiguration | None: if "/" not in key: key = str(ModelProviderID(key)) return self.configurations.get(key, default) # type: ignore class ProviderModelBundle(BaseModel): """ Provider model bundle. """ configuration: ProviderConfiguration model_type_instance: AIModel # pydantic configs model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())