feat: improve multi model credentials (#25009)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
@@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
config_from=args.get("config_from", ""),
|
||||
)
|
||||
|
||||
if args.get("config_from", "") == "predefined-model":
|
||||
@@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
)
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
@@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
|
||||
def _generate_provider_credential_name(self, session) -> str:
|
||||
"""
|
||||
Generate a unique credential name for provider.
|
||||
:return: credential name
|
||||
"""
|
||||
return self._generate_next_api_key_name(
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
|
||||
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
|
||||
"""
|
||||
Generate a unique credential name for custom model.
|
||||
:return: credential name
|
||||
"""
|
||||
return self._generate_next_api_key_name(
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
),
|
||||
)
|
||||
|
||||
def _generate_next_api_key_name(self, session, query_factory) -> str:
|
||||
"""
|
||||
Generate next available API KEY name by finding the highest numbered suffix.
|
||||
:param session: database session
|
||||
:param query_factory: function that returns the SQLAlchemy query
|
||||
:return: next available API KEY name
|
||||
"""
|
||||
try:
|
||||
stmt = query_factory()
|
||||
credential_records = session.execute(stmt).scalars().all()
|
||||
|
||||
if not credential_records:
|
||||
return "API KEY 1"
|
||||
|
||||
# Extract numbers from API KEY pattern using list comprehension
|
||||
pattern = re.compile(r"^API KEY\s+(\d+)$")
|
||||
numbers = [
|
||||
int(match.group(1))
|
||||
for cr in credential_records
|
||||
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
|
||||
]
|
||||
|
||||
# Return next sequential number
|
||||
next_number = max(numbers, default=0) + 1
|
||||
return f"API KEY {next_number}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error generating next credential name: %s", str(e))
|
||||
return "API KEY 1"
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
else:
|
||||
credential_name = self._generate_provider_credential_name(session)
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||
provider_record = self._get_provider_record(session)
|
||||
@@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update a saved provider credential (by credential_id).
|
||||
@@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_provider_credential_name_exists(
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
@@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.credential_name = credential_name
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
|
||||
if credential_name:
|
||||
credential_record.credential_name = credential_name
|
||||
session.commit()
|
||||
|
||||
if provider_record and provider_record.credential_id == credential_id:
|
||||
@@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||
)
|
||||
lb_credentials_cache.delete()
|
||||
|
||||
lb_config.credential_id = None
|
||||
lb_config.encrypted_config = None
|
||||
lb_config.enabled = False
|
||||
lb_config.name = "__delete__"
|
||||
lb_config.updated_at = naive_utc_now()
|
||||
session.add(lb_config)
|
||||
session.delete(lb_config)
|
||||
|
||||
# Check if this is the currently active credential
|
||||
provider_record = self._get_provider_record(session)
|
||||
@@ -822,7 +879,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create a custom model credential.
|
||||
@@ -833,10 +890,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=session
|
||||
)
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
@@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel):
|
||||
raise
|
||||
|
||||
def update_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Update a custom model credential.
|
||||
@@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
@@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.credential_name = credential_name
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
credential_record.credential_name = credential_name
|
||||
session.commit()
|
||||
|
||||
if provider_model_record and provider_model_record.credential_id == credential_id:
|
||||
@@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||
)
|
||||
lb_credentials_cache.delete()
|
||||
lb_config.credential_id = None
|
||||
lb_config.encrypted_config = None
|
||||
lb_config.enabled = False
|
||||
lb_config.name = "__delete__"
|
||||
lb_config.updated_at = naive_utc_now()
|
||||
session.add(lb_config)
|
||||
session.delete(lb_config)
|
||||
|
||||
# Check if this is the currently active credential
|
||||
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
|
||||
@@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
@@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
|
||||
if config.credential_source_type != "custom_model"
|
||||
]
|
||||
|
||||
if len(provider_model_lb_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
if any(config.name == "__delete__" for config in provider_model_lb_configs):
|
||||
has_invalid_load_balancing_configs = True
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
@@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type not in model_types:
|
||||
continue
|
||||
if model_configuration.unadded_to_model_list:
|
||||
continue
|
||||
if model and model != model_configuration.model:
|
||||
continue
|
||||
try:
|
||||
@@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
|
||||
if config.credential_source_type != "provider"
|
||||
]
|
||||
|
||||
if len(custom_model_lb_configs) > 1:
|
||||
load_balancing_enabled = True
|
||||
|
||||
if any(config.name == "__delete__" for config in custom_model_lb_configs):
|
||||
has_invalid_load_balancing_configs = True
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
|
||||
|
||||
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
|
||||
status = ModelStatus.CREDENTIAL_REMOVED
|
||||
|
@@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
unadded_to_model_list: Optional[bool] = False
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class UnaddedModelConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider unadded model configuration.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration.
|
||||
@@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
|
||||
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
||||
can_added_models: list[UnaddedModelConfiguration] = []
|
||||
|
||||
|
||||
class ModelLoadBalancingConfiguration(BaseModel):
|
||||
@@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
enabled: bool = True
|
||||
load_balancing_enabled: bool = False
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -22,6 +23,7 @@ from core.entities.provider_entities import (
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
SystemConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
@@ -537,6 +539,23 @@ class ProviderManager:
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
|
||||
"""
|
||||
Get all the credentials records from ProviderModelCredential by provider_name
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
|
||||
)
|
||||
|
||||
all_credentials = session.scalars(stmt).all()
|
||||
return all_credentials
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@@ -623,6 +642,44 @@ class ProviderManager:
|
||||
:param provider_model_records: provider model records
|
||||
:return:
|
||||
"""
|
||||
# Get custom provider configuration
|
||||
custom_provider_configuration = self._get_custom_provider_configuration(
|
||||
tenant_id, provider_entity, provider_records
|
||||
)
|
||||
|
||||
# Get all model credentials once
|
||||
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
|
||||
|
||||
# Get custom models which have not been added to the model list yet
|
||||
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
|
||||
|
||||
# Get custom model configurations
|
||||
custom_model_configurations = self._get_custom_model_configurations(
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
|
||||
)
|
||||
|
||||
can_added_models = [
|
||||
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
|
||||
]
|
||||
|
||||
return CustomConfiguration(
|
||||
provider=custom_provider_configuration,
|
||||
models=custom_model_configurations,
|
||||
can_added_models=can_added_models,
|
||||
)
|
||||
|
||||
def _get_custom_provider_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
) -> CustomProviderConfiguration | None:
|
||||
"""Get custom provider configuration."""
|
||||
# Find custom provider record (non-system)
|
||||
custom_provider_record = next(
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||
)
|
||||
|
||||
if not custom_provider_record:
|
||||
return None
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
@@ -630,57 +687,17 @@ class ProviderManager:
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider record
|
||||
custom_provider_record = None
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
custom_provider_record = provider_record
|
||||
|
||||
# Get custom provider credentials
|
||||
custom_provider_configuration = None
|
||||
if custom_provider_record:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
# Get and decrypt provider credentials
|
||||
provider_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=custom_provider_record.id,
|
||||
record_id=custom_provider_record.id,
|
||||
encrypted_config=custom_provider_record.encrypted_config,
|
||||
secret_variables=provider_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
is_provider=True,
|
||||
)
|
||||
|
||||
# Get cached provider credentials
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable) or "", # type: ignore
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(credentials=provider_credentials)
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
return CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
@@ -689,54 +706,79 @@ class ProviderManager:
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
def _get_can_added_models(
|
||||
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
|
||||
) -> list[dict]:
|
||||
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
|
||||
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
|
||||
|
||||
# Get not added custom models credentials
|
||||
not_added_custom_models_credentials = [
|
||||
credential
|
||||
for credential in all_model_credentials
|
||||
if (credential.model_name, credential.model_type) not in existing_model_set
|
||||
]
|
||||
|
||||
# Group credentials by model
|
||||
model_to_credentials = defaultdict(list)
|
||||
for credential in not_added_custom_models_credentials:
|
||||
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
return [
|
||||
{
|
||||
"model": model_key[0],
|
||||
"model_type": ModelType.value_of(model_key[1]),
|
||||
"available_model_credentials": [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in creds
|
||||
],
|
||||
}
|
||||
for model_key, creds in model_to_credentials.items()
|
||||
]
|
||||
|
||||
def _get_custom_model_configurations(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_model_records: list[ProviderModel],
|
||||
can_added_models: list[dict],
|
||||
all_model_credentials: Sequence[ProviderModelCredential],
|
||||
) -> list[CustomModelConfiguration]:
|
||||
"""Get custom model configurations."""
|
||||
# Get model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider model credentials
|
||||
# Create credentials lookup for efficient access
|
||||
credentials_map = defaultdict(list)
|
||||
for credential in all_model_credentials:
|
||||
credentials_map[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
custom_model_configurations = []
|
||||
|
||||
# Process existing model records
|
||||
for provider_model_record in provider_model_records:
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
# Use pre-fetched credentials instead of individual database calls
|
||||
available_model_credentials = [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in credentials_map.get(
|
||||
(provider_model_record.model_name, provider_model_record.model_type), []
|
||||
)
|
||||
]
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
# Get and decrypt model credentials
|
||||
provider_model_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
record_id=provider_model_record.id,
|
||||
encrypted_config=provider_model_record.encrypted_config,
|
||||
secret_variables=model_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
is_provider=False,
|
||||
)
|
||||
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
@@ -748,7 +790,71 @@ class ProviderManager:
|
||||
)
|
||||
)
|
||||
|
||||
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
|
||||
# Add models that can be added
|
||||
for model in can_added_models:
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=model["model"],
|
||||
model_type=model["model_type"],
|
||||
credentials=None,
|
||||
current_credential_id=None,
|
||||
current_credential_name=None,
|
||||
available_model_credentials=model["available_model_credentials"],
|
||||
unadded_to_model_list=True,
|
||||
)
|
||||
)
|
||||
|
||||
return custom_model_configurations
|
||||
|
||||
def _get_and_decrypt_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
record_id: str,
|
||||
encrypted_config: str | None,
|
||||
secret_variables: list[str],
|
||||
cache_type: ProviderCredentialsCacheType,
|
||||
is_provider: bool = False,
|
||||
) -> dict:
|
||||
"""Get and decrypt credentials with caching."""
|
||||
credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=record_id,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
|
||||
# Try to get from cache first
|
||||
cached_credentials = credentials_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
# Parse encrypted config
|
||||
if not encrypted_config:
|
||||
return {}
|
||||
|
||||
if is_provider and not encrypted_config.startswith("{"):
|
||||
return {"openai_api_key": encrypted_config}
|
||||
|
||||
try:
|
||||
credentials = cast(dict, json.loads(encrypted_config))
|
||||
except JSONDecodeError:
|
||||
return {}
|
||||
|
||||
# Decrypt secret variables
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in secret_variables:
|
||||
if variable in credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable) or "",
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# Cache the decrypted credentials
|
||||
credentials_cache.set(credentials=credentials)
|
||||
return credentials
|
||||
|
||||
def _to_system_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
@@ -956,18 +1062,6 @@ class ProviderManager:
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
@@ -1033,6 +1127,7 @@ class ProviderManager:
|
||||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
)
|
||||
)
|
||||
|
@@ -13,6 +13,7 @@ from core.entities.provider_entities import (
|
||||
CustomModelConfiguration,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel):
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: Optional[list[CredentialConfiguration]] = None
|
||||
custom_models: Optional[list[CustomModelConfiguration]] = None
|
||||
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
|
||||
|
||||
|
||||
class SystemConfigurationResponse(BaseModel):
|
||||
|
@@ -3,6 +3,8 @@ import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
from core.helper import encrypter
|
||||
@@ -69,7 +71,7 @@ class ModelLoadBalancingService:
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
@@ -100,6 +102,11 @@ class ModelLoadBalancingService:
|
||||
if provider_model_setting and provider_model_setting.load_balancing_enabled:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
@@ -108,6 +115,10 @@ class ModelLoadBalancingService:
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||
),
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
@@ -405,7 +416,7 @@ class ModelLoadBalancingService:
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
else:
|
||||
# create load balancing config
|
||||
if name in {"__inherit__", "__delete__"}:
|
||||
if name == "__inherit__":
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if credential_id:
|
||||
|
@@ -72,6 +72,7 @@ class ModelProviderService:
|
||||
|
||||
provider_config = provider_configuration.custom_configuration.provider
|
||||
model_config = provider_configuration.custom_configuration.models
|
||||
can_added_models = provider_configuration.custom_configuration.can_added_models
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
tenant_id=tenant_id,
|
||||
@@ -95,6 +96,7 @@ class ModelProviderService:
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
custom_models=model_config,
|
||||
can_added_models=can_added_models,
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
@@ -152,7 +154,7 @@ class ModelProviderService:
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
@@ -172,7 +174,7 @@ class ModelProviderService:
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update a saved provider credential (by credential_id).
|
||||
@@ -249,7 +251,7 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
create and save model credentials.
|
||||
@@ -278,7 +280,7 @@ class ModelProviderService:
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
update model credentials.
|
||||
|
Reference in New Issue
Block a user