feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from typing import Optional, cast
from sqlalchemy.exc import IntegrityError
@@ -15,6 +15,7 @@ from core.entities.provider_entities import (
ModelLoadBalancingConfiguration,
ModelSettings,
QuotaConfiguration,
QuotaUnit,
SystemConfiguration,
)
from core.helper import encrypter
@@ -116,8 +117,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
data=provider_entity,
name_func=lambda x: x.provider,
):
@@ -490,12 +491,13 @@ class ProviderManager:
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit,
quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -589,7 +591,9 @@ class ProviderManager:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
except ValueError:
pass
@@ -671,13 +675,9 @@ class ProviderManager:
# Get hosting configuration
hosting_configuration = ext_hosting_provider.hosting_configuration
if (
provider_entity.provider not in hosting_configuration.provider_map
or not hosting_configuration.provider_map.get(provider_entity.provider).enabled
):
return SystemConfiguration(enabled=False)
provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
if provider_hosting_configuration is None or not provider_hosting_configuration.enabled:
return SystemConfiguration(enabled=False)
# Convert provider_records to dict
quota_type_to_provider_records_dict = {}
@@ -688,14 +688,13 @@ class ProviderManager:
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
quota_configurations = []
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=0,
quota_limit=0,
is_valid=False,
@@ -708,7 +707,7 @@ class ProviderManager:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
@@ -725,12 +724,12 @@ class ProviderManager:
current_using_credentials = provider_hosting_configuration.credentials
if current_quota_type == ProviderQuotaType.FREE:
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type)
if provider_record:
if provider_record_quota_free:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_record.id,
identity_id=provider_record_quota_free.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
@@ -763,7 +762,7 @@ class ProviderManager:
except ValueError:
pass
current_using_credentials = provider_credentials
current_using_credentials = provider_credentials or {}
# cache provider credentials
provider_credentials_cache.set(credentials=current_using_credentials)
@@ -842,7 +841,7 @@ class ProviderManager:
else []
)
model_settings = []
model_settings: list[ModelSettings] = []
if not provider_model_settings:
return model_settings