feat: support pinning, including, and excluding for model providers and tools (#7419)

Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
Xiyuan Chen
2024-08-20 23:16:43 -04:00
committed by GitHub
parent 6c25d7bed3
commit 4e7b6aec3a
14 changed files with 363 additions and 57 deletions

View File

@@ -5,6 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@@ -18,12 +19,9 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@@ -45,6 +43,7 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
@@ -117,6 +116,16 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
continue
provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@@ -271,6 +280,24 @@ class ProviderManager:
)
)
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Get names of first model and its provider
:param tenant_id: workspace id
:param model_type: model type
:return: provider name, model name
"""
provider_configurations = self.get_configurations(tenant_id)
# get available models from provider_configurations
all_models = provider_configurations.get_models(
model_type=model_type,
only_active=False
)
return all_models[0].provider.provider, all_models[0].model
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
-> TenantDefaultModel:
"""