chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider: ProviderEntity
|
||||
preferred_provider_type: ProviderType
|
||||
@@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel):
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if (any(len(quota_configuration.restrict_models) > 0
|
||||
for quota_configuration in self.system_configuration.quota_configurations)
|
||||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
|
||||
if (
|
||||
any(
|
||||
len(quota_configuration.restrict_models) > 0
|
||||
for quota_configuration in self.system_configuration.quota_configurations
|
||||
)
|
||||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
|
||||
):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
||||
@@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel):
|
||||
if self.model_settings:
|
||||
# check if model is disabled by admin
|
||||
for model_setting in self.model_settings:
|
||||
if (model_setting.model_type == model_type
|
||||
and model_setting.model == model):
|
||||
if model_setting.model_type == model_type and model_setting.model == model:
|
||||
if not model_setting.enabled:
|
||||
raise ValueError(f'Model {model} is disabled.')
|
||||
raise ValueError(f"Model {model} is disabled.")
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
restrict_models = []
|
||||
@@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel):
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_models:
|
||||
for restrict_model in restrict_models:
|
||||
if (restrict_model.model_type == model_type
|
||||
and restrict_model.model == model
|
||||
and restrict_model.base_model_name):
|
||||
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||
if (
|
||||
restrict_model.model_type == model_type
|
||||
and restrict_model.model == model
|
||||
and restrict_model.base_model_name
|
||||
):
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
return copy_credentials
|
||||
else:
|
||||
@@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
current_quota_type = self.system_configuration.current_quota_type
|
||||
current_quota_configuration = next(
|
||||
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
|
||||
None
|
||||
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
|
||||
)
|
||||
|
||||
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
|
||||
SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
return (
|
||||
SystemConfigurationStatus.ACTIVE
|
||||
if current_quota_configuration.is_valid
|
||||
else SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
Check custom configuration available.
|
||||
:return:
|
||||
"""
|
||||
return (self.custom_configuration.provider is not None
|
||||
or len(self.custom_configuration.models) > 0)
|
||||
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
@@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel):
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
if self.provider.provider_credential_schema
|
||||
else [],
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
||||
@@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if provider_record:
|
||||
@@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# fix origin data
|
||||
if provider_record.encrypted_config:
|
||||
if not provider_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {
|
||||
"openai_api_key": provider_record.encrypted_config
|
||||
}
|
||||
original_credentials = {"openai_api_key": provider_record.encrypted_config}
|
||||
else:
|
||||
original_credentials = json.loads(provider_record.encrypted_config)
|
||||
else:
|
||||
@@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel):
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider,
|
||||
credentials=credentials
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
@@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
@@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# delete provider
|
||||
if provider_record:
|
||||
@@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
||||
-> Optional[dict]:
|
||||
def get_custom_model_credentials(
|
||||
self, model_type: ModelType, model: str, obfuscated: bool = False
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
|
||||
@@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel):
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
if self.provider.model_credential_schema
|
||||
else [],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
||||
-> tuple[ProviderModel, dict]:
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[ProviderModel, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if provider_model_record:
|
||||
try:
|
||||
original_credentials = json.loads(
|
||||
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
original_credentials = (
|
||||
json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
)
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
@@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel):
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
@@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider_model_record)
|
||||
db.session.commit()
|
||||
@@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
@@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# delete provider model
|
||||
if provider_model_record:
|
||||
@@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
@@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = True
|
||||
@@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=True
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
@@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = False
|
||||
@@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=False
|
||||
enabled=False,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
@@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(ProviderModelSetting) \
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
@@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_config_count = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).count()
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if load_balancing_config_count <= 1:
|
||||
raise ValueError('Model load balancing configuration must be more than 1.')
|
||||
raise ValueError("Model load balancing configuration must be more than 1.")
|
||||
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = True
|
||||
@@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=True
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
@@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = db.session.query(ProviderModelSetting) \
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model
|
||||
).first()
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = False
|
||||
@@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=False
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
@@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
preferred_model_provider = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider
|
||||
).first()
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
@@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel):
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value
|
||||
preferred_provider_type=provider_type.value,
|
||||
)
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
@@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(
|
||||
credential_form_schemas
|
||||
)
|
||||
credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = credentials.copy()
|
||||
@@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return copy_credentials
|
||||
|
||||
def get_provider_model(self, model_type: ModelType,
|
||||
model: str,
|
||||
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
|
||||
def get_provider_model(
|
||||
self, model_type: ModelType, model: str, only_active: bool = False
|
||||
) -> Optional[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get provider model.
|
||||
:param model_type: model type
|
||||
@@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return None
|
||||
|
||||
def get_provider_models(self, model_type: Optional[ModelType] = None,
|
||||
only_active: bool = False) -> list[ModelWithProviderEntity]:
|
||||
def get_provider_models(
|
||||
self, model_type: Optional[ModelType] = None, only_active: bool = False
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get provider models.
|
||||
:param model_type: model type
|
||||
@@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance,
|
||||
model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance,
|
||||
model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
)
|
||||
|
||||
if only_active:
|
||||
@@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel):
|
||||
# resort provider_models
|
||||
return sorted(provider_models, key=lambda x: x.model_type.value)
|
||||
|
||||
def _get_system_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
def _get_system_provider_models(
|
||||
self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
@@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=m.model_properties,
|
||||
deprecated=m.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if should_use_custom_model:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [
|
||||
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
ConfigurateMethod.CUSTOMIZABLE_MODEL
|
||||
]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = (
|
||||
provider_instance.get_model_instance(restrict_model.model_type)
|
||||
.get_customizable_model_schema_from_credentials(
|
||||
restrict_model.model,
|
||||
copy_credentials
|
||||
)
|
||||
)
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
restrict_model.model_type
|
||||
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
||||
except Exception as ex:
|
||||
logger.warning(f'get custom model schema failed, {ex}')
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
@@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
if (custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
||||
if (
|
||||
custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
||||
):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
@@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return provider_models
|
||||
|
||||
def _get_custom_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
def _get_custom_provider_models(
|
||||
self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
@@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel):
|
||||
deprecated=m.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status,
|
||||
load_balancing_enabled=load_balancing_enabled
|
||||
load_balancing_enabled=load_balancing_enabled,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_model_schema = (
|
||||
provider_instance.get_model_instance(model_configuration.model_type)
|
||||
.get_customizable_model_schema_from_credentials(
|
||||
model_configuration.model,
|
||||
model_configuration.credentials
|
||||
)
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
model_configuration.model_type
|
||||
).get_customizable_model_schema_from_credentials(
|
||||
model_configuration.model, model_configuration.credentials
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f'get custom model schema failed, {ex}')
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
@@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
load_balancing_enabled = False
|
||||
if (custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
|
||||
if (
|
||||
custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
||||
):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
@@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel):
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status,
|
||||
load_balancing_enabled=load_balancing_enabled
|
||||
load_balancing_enabled=load_balancing_enabled,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
|
||||
def get_models(self,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
only_active: bool = False) \
|
||||
-> list[ModelWithProviderEntity]:
|
||||
def get_models(
|
||||
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get available models.
|
||||
|
||||
@@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel):
|
||||
"""
|
||||
Provider model bundle.
|
||||
"""
|
||||
|
||||
configuration: ProviderConfiguration
|
||||
provider_instance: ModelProvider
|
||||
model_type_instance: AIModel
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True,
|
||||
protected_namespaces=())
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
|
Reference in New Issue
Block a user