chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
tenant_id: str
provider: ProviderEntity
preferred_provider_type: ProviderType
@@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any(len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations)
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
if (
any(
len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations
)
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
@@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
if self.model_settings:
# check if model is disabled by admin
for model_setting in self.model_settings:
if (model_setting.model_type == model_type
and model_setting.model == model):
if model_setting.model_type == model_type and model_setting.model == model:
if not model_setting.enabled:
raise ValueError(f'Model {model} is disabled.')
raise ValueError(f"Model {model} is disabled.")
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
@@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
if (restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name):
copy_credentials['base_model_name'] = restrict_model.base_model_name
if (
restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name
):
copy_credentials["base_model_name"] = restrict_model.base_model_name
return copy_credentials
else:
@@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
None
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
SystemConfigurationStatus.QUOTA_EXCEEDED
return (
SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
else SystemConfigurationStatus.QUOTA_EXCEEDED
)
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0)
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
@@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
if self.provider.provider_credential_schema
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
@@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
if self.provider.provider_credential_schema
else []
)
if provider_record:
@@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
# fix origin data
if provider_record.encrypted_config:
if not provider_record.encrypted_config.startswith("{"):
original_credentials = {
"openai_api_key": provider_record.encrypted_config
}
original_credentials = {"openai_api_key": provider_record.encrypted_config}
else:
original_credentials = json.loads(provider_record.encrypted_config)
else:
@@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider,
credentials=credentials
provider=self.provider.provider, credentials=credentials
)
for key, value in credentials.items():
@@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True
is_valid=True,
)
db.session.add(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
@@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value,
)
.first()
)
# delete provider
if provider_record:
@@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
def get_custom_model_credentials(
self, model_type: ModelType, model: str, obfuscated: bool = False
) -> Optional[dict]:
"""
Get custom model credentials.
@@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
if self.provider.model_credential_schema
else [],
)
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> tuple[ProviderModel, dict]:
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
@@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
if self.provider.model_credential_schema
else []
)
if provider_model_record:
try:
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
original_credentials = (
json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
)
except JSONDecodeError:
original_credentials = {}
@@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in credentials.items():
@@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True
is_valid=True,
)
db.session.add(provider_model_record)
db.session.commit()
@@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
@@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
provider_model_record = (
db.session.query(ProviderModel)
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
.first()
)
# delete provider model
if provider_model_record:
@@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
@@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.enabled = True
@@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True
enabled=True,
)
db.session.add(model_setting)
db.session.commit()
@@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.enabled = False
@@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False
enabled=False,
)
db.session.add(model_setting)
db.session.commit()
@@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
return db.session.query(ProviderModelSetting) \
return (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
@@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).count()
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.count()
)
if load_balancing_config_count <= 1:
raise ValueError('Model load balancing configuration must be more than 1.')
raise ValueError("Model load balancing configuration must be more than 1.")
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.load_balancing_enabled = True
@@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True
load_balancing_enabled=True,
)
db.session.add(model_setting)
db.session.commit()
@@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
model_setting = (
db.session.query(ProviderModelSetting)
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
.first()
)
if model_setting:
model_setting.load_balancing_enabled = False
@@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False
load_balancing_enabled=False,
)
db.session.add(model_setting)
db.session.commit()
@@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
return
# get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider
).first()
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider,
)
.first()
)
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
@@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value
preferred_provider_type=provider_type.value,
)
db.session.add(preferred_model_provider)
@@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
credential_form_schemas
)
credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
# Obfuscate provider credentials
copy_credentials = credentials.copy()
@@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
def get_provider_model(
self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
@@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
@@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance,
model_setting_map=model_setting_map
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance,
model_setting_map=model_setting_map
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
)
if only_active:
@@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
def _get_system_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status
status=status,
)
)
@@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
ConfigurateMethod.CUSTOMIZABLE_MODEL
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
@@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
continue
status = ModelStatus.ACTIVE
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
@@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status
status=status,
)
)
@@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
def _get_custom_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
load_balancing_enabled=load_balancing_enabled
load_balancing_enabled=load_balancing_enabled,
)
)
@@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
continue
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
@@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
status = ModelStatus.ACTIVE
load_balancing_enabled = False
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
@@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
load_balancing_enabled=load_balancing_enabled
load_balancing_enabled=load_balancing_enabled,
)
)
@@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
def get_models(
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get available models.
@@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
# pydantic configs
model_config = ConfigDict(arbitrary_types_allowed=True,
protected_namespaces=())
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())