feat: mypy for all type check (#10921)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
@@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
supported_model_types: Sequence[ModelType] = []
|
||||
|
||||
|
||||
class DefaultModelEntity(BaseModel):
|
||||
|
@@ -40,7 +40,7 @@ from models.provider import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
original_provider_configurate_methods = {}
|
||||
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
|
||||
|
||||
|
||||
class ProviderConfiguration(BaseModel):
|
||||
@@ -99,7 +99,8 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
|
||||
if self.system_configuration.credentials is None:
|
||||
return None
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_models:
|
||||
for restrict_model in restrict_models:
|
||||
@@ -124,7 +125,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return credentials
|
||||
|
||||
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
||||
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
|
||||
"""
|
||||
Get system configuration status.
|
||||
:return:
|
||||
@@ -136,6 +137,8 @@ class ProviderConfiguration(BaseModel):
|
||||
current_quota_configuration = next(
|
||||
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
|
||||
)
|
||||
if current_quota_configuration is None:
|
||||
return None
|
||||
|
||||
return (
|
||||
SystemConfigurationStatus.ACTIVE
|
||||
@@ -150,7 +153,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
def get_custom_credentials(self, obfuscated: bool = False):
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
@@ -172,7 +175,7 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -324,7 +327,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[ProviderModel, dict]:
|
||||
) -> tuple[Optional[ProviderModel], dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@@ -740,10 +743,10 @@ class ProviderConfiguration(BaseModel):
|
||||
if model_type:
|
||||
model_types.append(model_type)
|
||||
else:
|
||||
model_types = provider_instance.get_provider_schema().supported_model_types
|
||||
model_types = list(provider_instance.get_provider_schema().supported_model_types)
|
||||
|
||||
# Group model settings by model type and model
|
||||
model_setting_map = defaultdict(dict)
|
||||
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
|
||||
for model_setting in self.model_settings:
|
||||
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
|
||||
|
||||
@@ -822,54 +825,57 @@ class ProviderConfiguration(BaseModel):
|
||||
]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
if self.system_configuration.credentials is not None:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
restrict_model.model_type
|
||||
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
try:
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
restrict_model.model_type
|
||||
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
|
||||
if custom_model_schema.model_type not in model_types:
|
||||
continue
|
||||
if custom_model_schema.model_type not in model_types:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
if (
|
||||
custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
||||
):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
status = ModelStatus.ACTIVE
|
||||
if (
|
||||
custom_model_schema.model_type in model_setting_map
|
||||
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
|
||||
):
|
||||
model_setting = model_setting_map[custom_model_schema.model_type][
|
||||
custom_model_schema.model
|
||||
]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status,
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# if llm name not in restricted llm list, remove it
|
||||
restrict_model_names = [rm.model for rm in restrict_models]
|
||||
for m in provider_models:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
for model in provider_models:
|
||||
if model.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
model.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
model.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
@@ -1043,7 +1049,7 @@ class ProviderConfigurations(BaseModel):
|
||||
return iter(self.configurations)
|
||||
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return self.configurations.values()
|
||||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.configurations.get(key, default)
|
||||
|
Reference in New Issue
Block a user