feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
@@ -15,6 +15,7 @@ from core.entities.provider_entities import (
|
||||
ModelLoadBalancingConfiguration,
|
||||
ModelSettings,
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
SystemConfiguration,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
@@ -116,8 +117,8 @@ class ProviderManager:
|
||||
for provider_entity in provider_entities:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
@@ -490,12 +491,13 @@ class ProviderManager:
|
||||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
|
||||
provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
@@ -589,7 +591,9 @@ class ProviderManager:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
|
||||
provider_credentials.get(variable) or "", # type: ignore
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -671,13 +675,9 @@ class ProviderManager:
|
||||
# Get hosting configuration
|
||||
hosting_configuration = ext_hosting_provider.hosting_configuration
|
||||
|
||||
if (
|
||||
provider_entity.provider not in hosting_configuration.provider_map
|
||||
or not hosting_configuration.provider_map.get(provider_entity.provider).enabled
|
||||
):
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
|
||||
if provider_hosting_configuration is None or not provider_hosting_configuration.enabled:
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict = {}
|
||||
@@ -688,14 +688,13 @@ class ProviderManager:
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
|
||||
quota_configurations = []
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
|
||||
if provider_quota.quota_type == ProviderQuotaType.FREE:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=0,
|
||||
quota_limit=0,
|
||||
is_valid=False,
|
||||
@@ -708,7 +707,7 @@ class ProviderManager:
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
@@ -725,12 +724,12 @@ class ProviderManager:
|
||||
|
||||
current_using_credentials = provider_hosting_configuration.credentials
|
||||
if current_quota_type == ProviderQuotaType.FREE:
|
||||
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
|
||||
provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type)
|
||||
|
||||
if provider_record:
|
||||
if provider_record_quota_free:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
identity_id=provider_record_quota_free.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
|
||||
@@ -763,7 +762,7 @@ class ProviderManager:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
current_using_credentials = provider_credentials
|
||||
current_using_credentials = provider_credentials or {}
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(credentials=current_using_credentials)
|
||||
@@ -842,7 +841,7 @@ class ProviderManager:
|
||||
else []
|
||||
)
|
||||
|
||||
model_settings = []
|
||||
model_settings: list[ModelSettings] = []
|
||||
if not provider_model_settings:
|
||||
return model_settings
|
||||
|
||||
|
Reference in New Issue
Block a user