feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
@@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType] = []
class DefaultModelEntity(BaseModel):

View File

@@ -40,7 +40,7 @@ from models.provider import (
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
@@ -99,7 +99,8 @@ class ProviderConfiguration(BaseModel):
continue
restrict_models = quota_configuration.restrict_models
if self.system_configuration.credentials is None:
return None
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
@@ -124,7 +125,7 @@ class ProviderConfiguration(BaseModel):
return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus:
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
"""
Get system configuration status.
:return:
@@ -136,6 +137,8 @@ class ProviderConfiguration(BaseModel):
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
if current_quota_configuration is None:
return None
return (
SystemConfigurationStatus.ACTIVE
@@ -150,7 +153,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
def get_custom_credentials(self, obfuscated: bool = False):
"""
Get custom credentials.
@@ -172,7 +175,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -324,7 +327,7 @@ class ProviderConfiguration(BaseModel):
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
) -> tuple[Optional[ProviderModel], dict]:
"""
Validate custom model credentials.
@@ -740,10 +743,10 @@ class ProviderConfiguration(BaseModel):
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
model_types = list(provider_instance.get_provider_schema().supported_model_types)
# Group model settings by model type and model
model_setting_map = defaultdict(dict)
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
@@ -822,54 +825,57 @@ class ProviderConfiguration(BaseModel):
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
if self.system_configuration.credentials is not None:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
if not custom_model_schema:
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
if custom_model_schema.model_type not in model_types:
continue
status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
status = ModelStatus.ACTIVE
if (
custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
):
model_setting = model_setting_map[custom_model_schema.model_type][
custom_model_schema.model
]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=status,
)
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
for model in provider_models:
if model.model_type == ModelType.LLM and m.model not in restrict_model_names:
model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
model.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
@@ -1043,7 +1049,7 @@ class ProviderConfigurations(BaseModel):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
return iter(self.configurations.values())
def get(self, key, default=None):
return self.configurations.get(key, default)