refactor: Improve model status handling and structured output (#20586)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -3,7 +3,9 @@ from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
@@ -393,19 +395,13 @@ class ProviderManager:
|
||||
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
"""
|
||||
Get all provider records of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
# TODO: Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
# Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -416,17 +412,12 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider model records of the workspace
|
||||
provider_models = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
return provider_name_to_provider_model_records_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -437,17 +428,14 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
preferred_provider_types = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
return provider_name_to_preferred_provider_type_records_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -458,18 +446,14 @@ class ProviderManager:
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_model_settings = (
|
||||
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
|
||||
provider_model_setting
|
||||
)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -492,15 +476,14 @@ class ProviderManager:
|
||||
if not model_load_balancing_enabled:
|
||||
return {}
|
||||
|
||||
provider_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@@ -626,10 +609,9 @@ class ProviderManager:
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if (
|
||||
custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")
|
||||
):
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
@@ -733,7 +715,7 @@ class ProviderManager:
|
||||
return SystemConfiguration(enabled=False)
|
||||
|
||||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict = {}
|
||||
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
continue
|
||||
@@ -758,6 +740,11 @@ class ProviderManager:
|
||||
else:
|
||||
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
|
||||
|
||||
if provider_record.quota_used is None:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
@@ -791,10 +778,9 @@ class ProviderManager:
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
provider_credentials: dict[str, Any] = {}
|
||||
if provider_records and provider_records[0].encrypted_config:
|
||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
|
Reference in New Issue
Block a user