feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -2,7 +2,7 @@ import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional
from typing import Optional, Union
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@@ -88,11 +88,11 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=model_type,
model_type=model_type_enum,
model=model,
)
@@ -106,7 +106,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.order_by(LoadBalancingModelConfig.created_at)
@@ -124,7 +124,7 @@ class ModelLoadBalancingService:
if not inherit_config_exists:
# Initialize the inherit configuration
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
# prepend the inherit configuration
load_balancing_configs.insert(0, inherit_config)
@@ -148,7 +148,7 @@ class ModelLoadBalancingService:
tenant_id=tenant_id,
provider=provider,
model=model,
model_type=model_type,
model_type=model_type_enum,
config_id=load_balancing_config.id,
)
@@ -214,7 +214,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
@@ -222,7 +222,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -300,7 +300,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
@@ -310,7 +310,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
@@ -359,7 +359,7 @@ class ModelLoadBalancingService:
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
@@ -395,7 +395,7 @@ class ModelLoadBalancingService:
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
@@ -405,7 +405,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
@@ -450,7 +450,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type = ModelType.value_of(model_type)
model_type_enum = ModelType.value_of(model_type)
load_balancing_model_config = None
if config_id:
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -474,7 +474,7 @@ class ModelLoadBalancingService:
self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
model_type=model_type,
model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config,
@@ -547,19 +547,14 @@ class ModelLoadBalancingService:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""
Get form schemas.
:param provider_configuration: provider configuration
:return:
"""
# Get credential form schemas from model credential schema or provider credential schema
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
"""Get form schemas."""
if provider_configuration.provider.model_credential_schema:
credential_schema = provider_configuration.provider.model_credential_schema
return provider_configuration.provider.model_credential_schema
elif provider_configuration.provider.provider_credential_schema:
return provider_configuration.provider.provider_credential_schema
else:
credential_schema = provider_configuration.provider.provider_credential_schema
return credential_schema
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
"""