feat: mypy for all type check (#10921)
This commit is contained in:
@@ -2,7 +2,7 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
@@ -88,11 +88,11 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
# Get provider model setting
|
||||
provider_model_setting = provider_configuration.get_provider_model_setting(
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
)
|
||||
|
||||
@@ -106,7 +106,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
@@ -124,7 +124,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
if not inherit_config_exists:
|
||||
# Initialize the inherit configuration
|
||||
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
|
||||
inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
|
||||
|
||||
# prepend the inherit configuration
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
@@ -148,7 +148,7 @@ class ModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
config_id=load_balancing_config.id,
|
||||
)
|
||||
|
||||
@@ -214,7 +214,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = (
|
||||
@@ -222,7 +222,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -300,7 +300,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
@@ -310,7 +310,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
@@ -359,7 +359,7 @@ class ModelLoadBalancingService:
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_config,
|
||||
@@ -395,7 +395,7 @@ class ModelLoadBalancingService:
|
||||
credentials = self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False,
|
||||
@@ -405,7 +405,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
@@ -450,7 +450,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -474,7 +474,7 @@ class ModelLoadBalancingService:
|
||||
self._custom_credentials_validate(
|
||||
tenant_id=tenant_id,
|
||||
provider_configuration=provider_configuration,
|
||||
model_type=model_type,
|
||||
model_type=model_type_enum,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
@@ -547,19 +547,14 @@ class ModelLoadBalancingService:
|
||||
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""
|
||||
Get form schemas.
|
||||
:param provider_configuration: provider configuration
|
||||
:return:
|
||||
"""
|
||||
# Get credential form schemas from model credential schema or provider credential schema
|
||||
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
|
||||
"""Get form schemas."""
|
||||
if provider_configuration.provider.model_credential_schema:
|
||||
credential_schema = provider_configuration.provider.model_credential_schema
|
||||
return provider_configuration.provider.model_credential_schema
|
||||
elif provider_configuration.provider.provider_credential_schema:
|
||||
return provider_configuration.provider.provider_credential_schema
|
||||
else:
|
||||
credential_schema = provider_configuration.provider.provider_credential_schema
|
||||
|
||||
return credential_schema
|
||||
raise ValueError("No credential schema found")
|
||||
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user