chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -90,8 +90,7 @@ class ProviderManager:
|
||||
|
||||
# Initialize trial provider records if not exist
|
||||
provider_name_to_provider_records_dict = self._init_trial_provider_records(
|
||||
tenant_id,
|
||||
provider_name_to_provider_records_dict
|
||||
tenant_id, provider_name_to_provider_records_dict
|
||||
)
|
||||
|
||||
# Get all provider model records of the workspace
|
||||
@@ -107,22 +106,20 @@ class ProviderManager:
|
||||
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
|
||||
|
||||
# Get All load balancing configs
|
||||
provider_name_to_provider_load_balancing_model_configs_dict \
|
||||
= self._get_all_provider_load_balancing_configs(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(
|
||||
tenant_id=tenant_id
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs(
|
||||
tenant_id
|
||||
)
|
||||
|
||||
provider_configurations = ProviderConfigurations(tenant_id=tenant_id)
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
for provider_entity in provider_entities:
|
||||
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
continue
|
||||
|
||||
@@ -132,18 +129,11 @@ class ProviderManager:
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id,
|
||||
provider_entity,
|
||||
provider_records,
|
||||
provider_model_records
|
||||
tenant_id, provider_entity, provider_records, provider_model_records
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
system_configuration = self._to_system_configuration(
|
||||
tenant_id,
|
||||
provider_entity,
|
||||
provider_records
|
||||
)
|
||||
system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records)
|
||||
|
||||
# Get preferred provider type
|
||||
preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name)
|
||||
@@ -173,14 +163,15 @@ class ProviderManager:
|
||||
provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
|
||||
|
||||
# Get provider load balancing configs
|
||||
provider_load_balancing_configs \
|
||||
= provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
|
||||
provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get(
|
||||
provider_name
|
||||
)
|
||||
|
||||
# Convert to model settings
|
||||
model_settings = self._to_model_settings(
|
||||
provider_entity=provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=provider_load_balancing_configs
|
||||
load_balancing_model_configs=provider_load_balancing_configs,
|
||||
)
|
||||
|
||||
provider_configuration = ProviderConfiguration(
|
||||
@@ -190,7 +181,7 @@ class ProviderManager:
|
||||
using_provider_type=using_provider_type,
|
||||
system_configuration=system_configuration,
|
||||
custom_configuration=custom_configuration,
|
||||
model_settings=model_settings
|
||||
model_settings=model_settings,
|
||||
)
|
||||
|
||||
provider_configurations[provider_name] = provider_configuration
|
||||
@@ -219,7 +210,7 @@ class ProviderManager:
|
||||
return ProviderModelBundle(
|
||||
configuration=provider_configuration,
|
||||
provider_instance=provider_instance,
|
||||
model_type_instance=model_type_instance
|
||||
model_type_instance=model_type_instance,
|
||||
)
|
||||
|
||||
def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]:
|
||||
@@ -231,11 +222,14 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
# Get the corresponding TenantDefaultModel record
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
default_model = (
|
||||
db.session.query(TenantDefaultModel)
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# If it does not exist, get the first available provider model from get_configurations
|
||||
# and update the TenantDefaultModel record
|
||||
@@ -244,20 +238,18 @@ class ProviderManager:
|
||||
provider_configurations = self.get_configurations(tenant_id)
|
||||
|
||||
# get available models from provider_configurations
|
||||
available_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=True
|
||||
)
|
||||
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
|
||||
|
||||
if available_models:
|
||||
available_model = next((model for model in available_models if model.model == "gpt-4"),
|
||||
available_models[0])
|
||||
available_model = next(
|
||||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
provider_name=available_model.provider.provider,
|
||||
model_name=available_model.model
|
||||
model_name=available_model.model,
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
@@ -276,8 +268,8 @@ class ProviderManager:
|
||||
label=provider_schema.label,
|
||||
icon_small=provider_schema.icon_small,
|
||||
icon_large=provider_schema.icon_large,
|
||||
supported_model_types=provider_schema.supported_model_types
|
||||
)
|
||||
supported_model_types=provider_schema.supported_model_types,
|
||||
),
|
||||
)
|
||||
|
||||
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
@@ -291,15 +283,13 @@ class ProviderManager:
|
||||
provider_configurations = self.get_configurations(tenant_id)
|
||||
|
||||
# get available models from provider_configurations
|
||||
all_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=False
|
||||
)
|
||||
all_models = provider_configurations.get_models(model_type=model_type, only_active=False)
|
||||
|
||||
return all_models[0].provider.provider, all_models[0].model
|
||||
|
||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
||||
-> TenantDefaultModel:
|
||||
def update_default_model_record(
|
||||
self, tenant_id: str, model_type: ModelType, provider: str, model: str
|
||||
) -> TenantDefaultModel:
|
||||
"""
|
||||
Update default model record.
|
||||
|
||||
@@ -314,10 +304,7 @@ class ProviderManager:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# get available models from provider_configurations
|
||||
available_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=True
|
||||
)
|
||||
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
|
||||
|
||||
# check if the model is exist in available models
|
||||
model_names = [model.model for model in available_models]
|
||||
@@ -325,11 +312,14 @@ class ProviderManager:
|
||||
raise ValueError(f"Model {model} does not exist.")
|
||||
|
||||
# Get the list of available models from get_configurations and check if it is LLM
|
||||
default_model = db.session.query(TenantDefaultModel) \
|
||||
default_model = (
|
||||
db.session.query(TenantDefaultModel)
|
||||
.filter(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# create or update TenantDefaultModel record
|
||||
if default_model:
|
||||
@@ -358,11 +348,7 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
providers = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.is_valid == True
|
||||
).all()
|
||||
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
@@ -379,11 +365,11 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider model records of the workspace
|
||||
provider_models = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.is_valid == True
|
||||
).all()
|
||||
provider_models = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
for provider_model in provider_models:
|
||||
@@ -399,10 +385,11 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id
|
||||
).all()
|
||||
preferred_provider_types = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
@@ -419,15 +406,17 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_model_settings = db.session.query(ProviderModelSetting) \
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == tenant_id
|
||||
).all()
|
||||
provider_model_settings = (
|
||||
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
(provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name]
|
||||
.append(provider_model_setting))
|
||||
(
|
||||
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
|
||||
provider_model_setting
|
||||
)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@@ -445,27 +434,30 @@ class ProviderManager:
|
||||
model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
|
||||
redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
|
||||
else:
|
||||
cache_result = cache_result.decode('utf-8')
|
||||
model_load_balancing_enabled = cache_result == 'True'
|
||||
cache_result = cache_result.decode("utf-8")
|
||||
model_load_balancing_enabled = cache_result == "True"
|
||||
|
||||
if not model_load_balancing_enabled:
|
||||
return {}
|
||||
|
||||
provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id
|
||||
).all()
|
||||
provider_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
(provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name]
|
||||
.append(provider_load_balancing_config))
|
||||
(
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(tenant_id: str,
|
||||
provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
|
||||
) -> dict[str, list]:
|
||||
"""
|
||||
Initialize trial provider records if not exists.
|
||||
|
||||
@@ -489,8 +481,9 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \
|
||||
= provider_record
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type == ProviderQuotaType.TRIAL:
|
||||
@@ -504,19 +497,22 @@ class ProviderManager:
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit,
|
||||
quota_used=0,
|
||||
is_valid=True
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
provider_record = db.session.query(Provider) \
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value
|
||||
).first()
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_record and not provider_record.is_valid:
|
||||
provider_record.is_valid = True
|
||||
@@ -526,11 +522,13 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
def _to_custom_configuration(self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel]) -> CustomConfiguration:
|
||||
def _to_custom_configuration(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel],
|
||||
) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
|
||||
@@ -543,7 +541,8 @@ class ProviderManager:
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
if provider_entity.provider_credential_schema else []
|
||||
if provider_entity.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider record
|
||||
@@ -563,7 +562,7 @@ class ProviderManager:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=custom_provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
|
||||
# Get cached provider credentials
|
||||
@@ -572,11 +571,11 @@ class ProviderManager:
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if (custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")):
|
||||
provider_credentials = {
|
||||
"openai_api_key": custom_provider_record.encrypted_config
|
||||
}
|
||||
if (
|
||||
custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")
|
||||
):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
@@ -590,28 +589,23 @@ class ProviderManager:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(
|
||||
credentials=provider_credentials
|
||||
)
|
||||
provider_credentials_cache.set(credentials=provider_credentials)
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials
|
||||
)
|
||||
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema else []
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider model credentials
|
||||
@@ -621,9 +615,7 @@ class ProviderManager:
|
||||
continue
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
# Get cached provider model credentials
|
||||
@@ -645,15 +637,13 @@ class ProviderManager:
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(
|
||||
credentials=provider_model_credentials
|
||||
)
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
@@ -661,19 +651,15 @@ class ProviderManager:
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
credentials=provider_model_credentials
|
||||
credentials=provider_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
return CustomConfiguration(
|
||||
provider=custom_provider_configuration,
|
||||
models=custom_model_configurations
|
||||
)
|
||||
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
|
||||
|
||||
def _to_system_configuration(self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider]) -> SystemConfiguration:
|
||||
def _to_system_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
) -> SystemConfiguration:
|
||||
"""
|
||||
Convert to system configuration.
|
||||
|
||||
@@ -685,11 +671,11 @@ class ProviderManager:
|
||||
# Get hosting configuration
|
||||
hosting_configuration = ext_hosting_provider.hosting_configuration
|
||||
|
||||
if provider_entity.provider not in hosting_configuration.provider_map \
|
||||
or not hosting_configuration.provider_map.get(provider_entity.provider).enabled:
|
||||
return SystemConfiguration(
|
||||
enabled=False
|
||||
)
|
||||
if (
|
||||
provider_entity.provider not in hosting_configuration.provider_map
|
||||
or not hosting_configuration.provider_map.get(provider_entity.provider).enabled
|
||||
):
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
|
||||
|
||||
@@ -699,8 +685,9 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \
|
||||
= provider_record
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
|
||||
quota_configurations = []
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
@@ -712,7 +699,7 @@ class ProviderManager:
|
||||
quota_used=0,
|
||||
quota_limit=0,
|
||||
is_valid=False,
|
||||
restrict_models=provider_quota.restrict_models
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
@@ -724,16 +711,15 @@ class ProviderManager:
|
||||
quota_unit=provider_hosting_configuration.quota_unit,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
if len(quota_configurations) == 0:
|
||||
return SystemConfiguration(
|
||||
enabled=False
|
||||
)
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
current_quota_type = self._choice_current_using_quota_type(quota_configurations)
|
||||
|
||||
@@ -745,7 +731,7 @@ class ProviderManager:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
|
||||
# Get cached provider credentials
|
||||
@@ -760,7 +746,8 @@ class ProviderManager:
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
if provider_entity.provider_credential_schema else []
|
||||
if provider_entity.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
@@ -771,9 +758,7 @@ class ProviderManager:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -781,9 +766,7 @@ class ProviderManager:
|
||||
current_using_credentials = provider_credentials
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(
|
||||
credentials=current_using_credentials
|
||||
)
|
||||
provider_credentials_cache.set(credentials=current_using_credentials)
|
||||
else:
|
||||
current_using_credentials = cached_provider_credentials
|
||||
else:
|
||||
@@ -794,7 +777,7 @@ class ProviderManager:
|
||||
enabled=True,
|
||||
current_quota_type=current_quota_type,
|
||||
quota_configurations=quota_configurations,
|
||||
credentials=current_using_credentials
|
||||
credentials=current_using_credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -809,8 +792,7 @@ class ProviderManager:
|
||||
"""
|
||||
# convert to dict
|
||||
quota_type_to_quota_configuration_dict = {
|
||||
quota_configuration.quota_type: quota_configuration
|
||||
for quota_configuration in quota_configurations
|
||||
quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations
|
||||
}
|
||||
|
||||
last_quota_configuration = None
|
||||
@@ -823,7 +805,7 @@ class ProviderManager:
|
||||
if last_quota_configuration:
|
||||
return last_quota_configuration.quota_type
|
||||
|
||||
raise ValueError('No quota type available')
|
||||
raise ValueError("No quota type available")
|
||||
|
||||
@staticmethod
|
||||
def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
@@ -840,10 +822,12 @@ class ProviderManager:
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def _to_model_settings(self, provider_entity: ProviderEntity,
|
||||
provider_model_settings: Optional[list[ProviderModelSetting]] = None,
|
||||
load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
|
||||
-> list[ModelSettings]:
|
||||
def _to_model_settings(
|
||||
self,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_model_settings: Optional[list[ProviderModelSetting]] = None,
|
||||
load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None,
|
||||
) -> list[ModelSettings]:
|
||||
"""
|
||||
Convert to model settings.
|
||||
:param provider_entity: provider entity
|
||||
@@ -854,7 +838,8 @@ class ProviderManager:
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema else []
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
model_settings = []
|
||||
@@ -865,24 +850,28 @@ class ProviderManager:
|
||||
load_balancing_configs = []
|
||||
if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
|
||||
for load_balancing_model_config in load_balancing_model_configs:
|
||||
if (load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type):
|
||||
if (
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.encrypted_config:
|
||||
if load_balancing_model_config.name == "__inherit__":
|
||||
load_balancing_configs.append(ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={}
|
||||
))
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=load_balancing_model_config.tenant_id,
|
||||
identity_id=load_balancing_model_config.id,
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||
)
|
||||
|
||||
# Get cached provider model credentials
|
||||
@@ -897,7 +886,8 @@ class ProviderManager:
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
|
||||
load_balancing_model_config.tenant_id)
|
||||
load_balancing_model_config.tenant_id
|
||||
)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
@@ -905,30 +895,30 @@ class ProviderManager:
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(
|
||||
credentials=provider_model_credentials
|
||||
)
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
load_balancing_configs.append(ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials=provider_model_credentials
|
||||
))
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials=provider_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
model_settings.append(
|
||||
ModelSettings(
|
||||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else []
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
)
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user