feat: improve multi model credentials (#25009)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
@@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
@@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
|
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||||
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
|
tenant_id=tenant_id,
|
||||||
|
provider=provider,
|
||||||
|
model=args["model"],
|
||||||
|
model_type=args["model_type"],
|
||||||
|
config_from=args.get("config_from", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.get("config_from", "") == "predefined-model":
|
if args.get("config_from", "") == "predefined-model":
|
||||||
@@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
choices=[mt.value for mt in ModelType],
|
choices=[mt.value for mt in ModelType],
|
||||||
location="json",
|
location="json",
|
||||||
)
|
)
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
)
|
)
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterator, Sequence
|
from collections.abc import Iterator, Sequence
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
@@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel):
|
|||||||
with Session(db.engine) as new_session:
|
with Session(db.engine) as new_session:
|
||||||
return _validate(new_session)
|
return _validate(new_session)
|
||||||
|
|
||||||
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
|
def _generate_provider_credential_name(self, session) -> str:
|
||||||
|
"""
|
||||||
|
Generate a unique credential name for provider.
|
||||||
|
:return: credential name
|
||||||
|
"""
|
||||||
|
return self._generate_next_api_key_name(
|
||||||
|
session=session,
|
||||||
|
query_factory=lambda: select(ProviderCredential).where(
|
||||||
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
|
ProviderCredential.provider_name == self.provider.provider,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
|
||||||
|
"""
|
||||||
|
Generate a unique credential name for custom model.
|
||||||
|
:return: credential name
|
||||||
|
"""
|
||||||
|
return self._generate_next_api_key_name(
|
||||||
|
session=session,
|
||||||
|
query_factory=lambda: select(ProviderModelCredential).where(
|
||||||
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
|
ProviderModelCredential.provider_name == self.provider.provider,
|
||||||
|
ProviderModelCredential.model_name == model,
|
||||||
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_next_api_key_name(self, session, query_factory) -> str:
|
||||||
|
"""
|
||||||
|
Generate next available API KEY name by finding the highest numbered suffix.
|
||||||
|
:param session: database session
|
||||||
|
:param query_factory: function that returns the SQLAlchemy query
|
||||||
|
:return: next available API KEY name
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stmt = query_factory()
|
||||||
|
credential_records = session.execute(stmt).scalars().all()
|
||||||
|
|
||||||
|
if not credential_records:
|
||||||
|
return "API KEY 1"
|
||||||
|
|
||||||
|
# Extract numbers from API KEY pattern using list comprehension
|
||||||
|
pattern = re.compile(r"^API KEY\s+(\d+)$")
|
||||||
|
numbers = [
|
||||||
|
int(match.group(1))
|
||||||
|
for cr in credential_records
|
||||||
|
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
|
||||||
|
]
|
||||||
|
|
||||||
|
# Return next sequential number
|
||||||
|
next_number = max(numbers, default=0) + 1
|
||||||
|
return f"API KEY {next_number}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error generating next credential name: %s", str(e))
|
||||||
|
return "API KEY 1"
|
||||||
|
|
||||||
|
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
|
||||||
"""
|
"""
|
||||||
Add custom provider credentials.
|
Add custom provider credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
@@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
if credential_name and self._check_provider_credential_name_exists(
|
||||||
|
credential_name=credential_name, session=session
|
||||||
|
):
|
||||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||||
|
else:
|
||||||
|
credential_name = self._generate_provider_credential_name(session)
|
||||||
|
|
||||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||||
provider_record = self._get_provider_record(session)
|
provider_record = self._get_provider_record(session)
|
||||||
@@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
self,
|
self,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
credential_id: str,
|
credential_id: str,
|
||||||
credential_name: str,
|
credential_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
update a saved provider credential (by credential_id).
|
update a saved provider credential (by credential_id).
|
||||||
@@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if self._check_provider_credential_name_exists(
|
if credential_name and self._check_provider_credential_name_exists(
|
||||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||||
):
|
):
|
||||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||||
@@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
try:
|
try:
|
||||||
# Update credential
|
# Update credential
|
||||||
credential_record.encrypted_config = json.dumps(credentials)
|
credential_record.encrypted_config = json.dumps(credentials)
|
||||||
credential_record.credential_name = credential_name
|
|
||||||
credential_record.updated_at = naive_utc_now()
|
credential_record.updated_at = naive_utc_now()
|
||||||
|
if credential_name:
|
||||||
|
credential_record.credential_name = credential_name
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
if provider_record and provider_record.credential_id == credential_id:
|
if provider_record and provider_record.credential_id == credential_id:
|
||||||
@@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||||
)
|
)
|
||||||
lb_credentials_cache.delete()
|
lb_credentials_cache.delete()
|
||||||
|
session.delete(lb_config)
|
||||||
lb_config.credential_id = None
|
|
||||||
lb_config.encrypted_config = None
|
|
||||||
lb_config.enabled = False
|
|
||||||
lb_config.name = "__delete__"
|
|
||||||
lb_config.updated_at = naive_utc_now()
|
|
||||||
session.add(lb_config)
|
|
||||||
|
|
||||||
# Check if this is the currently active credential
|
# Check if this is the currently active credential
|
||||||
provider_record = self._get_provider_record(session)
|
provider_record = self._get_provider_record(session)
|
||||||
@@ -822,7 +879,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
return _validate(new_session)
|
return _validate(new_session)
|
||||||
|
|
||||||
def create_custom_model_credential(
|
def create_custom_model_credential(
|
||||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a custom model credential.
|
Create a custom model credential.
|
||||||
@@ -833,10 +890,14 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if self._check_custom_model_credential_name_exists(
|
if credential_name and self._check_custom_model_credential_name_exists(
|
||||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||||
):
|
):
|
||||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||||
|
else:
|
||||||
|
credential_name = self._generate_custom_model_credential_name(
|
||||||
|
model=model, model_type=model_type, session=session
|
||||||
|
)
|
||||||
# validate custom model config
|
# validate custom model config
|
||||||
credentials = self.validate_custom_model_credentials(
|
credentials = self.validate_custom_model_credentials(
|
||||||
model_type=model_type, model=model, credentials=credentials, session=session
|
model_type=model_type, model=model, credentials=credentials, session=session
|
||||||
@@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def update_custom_model_credential(
|
def update_custom_model_credential(
|
||||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update a custom model credential.
|
Update a custom model credential.
|
||||||
@@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
if self._check_custom_model_credential_name_exists(
|
if credential_name and self._check_custom_model_credential_name_exists(
|
||||||
model=model,
|
model=model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
credential_name=credential_name,
|
credential_name=credential_name,
|
||||||
@@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
try:
|
try:
|
||||||
# Update credential
|
# Update credential
|
||||||
credential_record.encrypted_config = json.dumps(credentials)
|
credential_record.encrypted_config = json.dumps(credentials)
|
||||||
credential_record.credential_name = credential_name
|
|
||||||
credential_record.updated_at = naive_utc_now()
|
credential_record.updated_at = naive_utc_now()
|
||||||
|
if credential_name:
|
||||||
|
credential_record.credential_name = credential_name
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
if provider_model_record and provider_model_record.credential_id == credential_id:
|
if provider_model_record and provider_model_record.credential_id == credential_id:
|
||||||
@@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
||||||
)
|
)
|
||||||
lb_credentials_cache.delete()
|
lb_credentials_cache.delete()
|
||||||
lb_config.credential_id = None
|
session.delete(lb_config)
|
||||||
lb_config.encrypted_config = None
|
|
||||||
lb_config.enabled = False
|
|
||||||
lb_config.name = "__delete__"
|
|
||||||
lb_config.updated_at = naive_utc_now()
|
|
||||||
session.add(lb_config)
|
|
||||||
|
|
||||||
# Check if this is the currently active credential
|
# Check if this is the currently active credential
|
||||||
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
|
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
|
||||||
@@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
provider_name=self.provider.provider,
|
provider_name=self.provider.provider,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
model_type=model_type.to_origin_model_type(),
|
model_type=model_type.to_origin_model_type(),
|
||||||
|
is_valid=True,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
if config.credential_source_type != "custom_model"
|
if config.credential_source_type != "custom_model"
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(provider_model_lb_configs) > 1:
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||||
load_balancing_enabled = True
|
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||||
|
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
|
||||||
if any(config.name == "__delete__" for config in provider_model_lb_configs):
|
|
||||||
has_invalid_load_balancing_configs = True
|
|
||||||
|
|
||||||
provider_models.append(
|
provider_models.append(
|
||||||
ModelWithProviderEntity(
|
ModelWithProviderEntity(
|
||||||
@@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel):
|
|||||||
for model_configuration in self.custom_configuration.models:
|
for model_configuration in self.custom_configuration.models:
|
||||||
if model_configuration.model_type not in model_types:
|
if model_configuration.model_type not in model_types:
|
||||||
continue
|
continue
|
||||||
|
if model_configuration.unadded_to_model_list:
|
||||||
|
continue
|
||||||
if model and model != model_configuration.model:
|
if model and model != model_configuration.model:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel):
|
|||||||
if config.credential_source_type != "provider"
|
if config.credential_source_type != "provider"
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(custom_model_lb_configs) > 1:
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||||
load_balancing_enabled = True
|
# when the user enable load_balancing but available configs are less than 2 display warning
|
||||||
|
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
|
||||||
if any(config.name == "__delete__" for config in custom_model_lb_configs):
|
|
||||||
has_invalid_load_balancing_configs = True
|
|
||||||
|
|
||||||
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
|
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
|
||||||
status = ModelStatus.CREDENTIAL_REMOVED
|
status = ModelStatus.CREDENTIAL_REMOVED
|
||||||
|
@@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
|
|||||||
current_credential_id: Optional[str] = None
|
current_credential_id: Optional[str] = None
|
||||||
current_credential_name: Optional[str] = None
|
current_credential_name: Optional[str] = None
|
||||||
available_model_credentials: list[CredentialConfiguration] = []
|
available_model_credentials: list[CredentialConfiguration] = []
|
||||||
|
unadded_to_model_list: Optional[bool] = False
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
|
class UnaddedModelConfiguration(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider unadded model configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
model_type: ModelType
|
||||||
|
|
||||||
|
|
||||||
class CustomConfiguration(BaseModel):
|
class CustomConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for provider custom configuration.
|
Model class for provider custom configuration.
|
||||||
@@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
|
|||||||
|
|
||||||
provider: Optional[CustomProviderConfiguration] = None
|
provider: Optional[CustomProviderConfiguration] = None
|
||||||
models: list[CustomModelConfiguration] = []
|
models: list[CustomModelConfiguration] = []
|
||||||
|
can_added_models: list[UnaddedModelConfiguration] = []
|
||||||
|
|
||||||
|
|
||||||
class ModelLoadBalancingConfiguration(BaseModel):
|
class ModelLoadBalancingConfiguration(BaseModel):
|
||||||
@@ -144,6 +155,7 @@ class ModelSettings(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
|
load_balancing_enabled: bool = False
|
||||||
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
|
@@ -1,8 +1,9 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
@@ -22,6 +23,7 @@ from core.entities.provider_entities import (
|
|||||||
QuotaConfiguration,
|
QuotaConfiguration,
|
||||||
QuotaUnit,
|
QuotaUnit,
|
||||||
SystemConfiguration,
|
SystemConfiguration,
|
||||||
|
UnaddedModelConfiguration,
|
||||||
)
|
)
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
@@ -537,6 +539,23 @@ class ProviderManager:
|
|||||||
for credential in available_credentials
|
for credential in available_credentials
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
|
||||||
|
"""
|
||||||
|
Get all the credentials records from ProviderModelCredential by provider_name
|
||||||
|
|
||||||
|
:param tenant_id: workspace id
|
||||||
|
:param provider_name: provider name
|
||||||
|
|
||||||
|
"""
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
stmt = select(ProviderModelCredential).where(
|
||||||
|
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
|
||||||
|
)
|
||||||
|
|
||||||
|
all_credentials = session.scalars(stmt).all()
|
||||||
|
return all_credentials
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_trial_provider_records(
|
def _init_trial_provider_records(
|
||||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||||
@@ -623,6 +642,44 @@ class ProviderManager:
|
|||||||
:param provider_model_records: provider model records
|
:param provider_model_records: provider model records
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
# Get custom provider configuration
|
||||||
|
custom_provider_configuration = self._get_custom_provider_configuration(
|
||||||
|
tenant_id, provider_entity, provider_records
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all model credentials once
|
||||||
|
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
|
||||||
|
|
||||||
|
# Get custom models which have not been added to the model list yet
|
||||||
|
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
|
||||||
|
|
||||||
|
# Get custom model configurations
|
||||||
|
custom_model_configurations = self._get_custom_model_configurations(
|
||||||
|
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
can_added_models = [
|
||||||
|
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
|
||||||
|
]
|
||||||
|
|
||||||
|
return CustomConfiguration(
|
||||||
|
provider=custom_provider_configuration,
|
||||||
|
models=custom_model_configurations,
|
||||||
|
can_added_models=can_added_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_custom_provider_configuration(
|
||||||
|
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||||
|
) -> CustomProviderConfiguration | None:
|
||||||
|
"""Get custom provider configuration."""
|
||||||
|
# Find custom provider record (non-system)
|
||||||
|
custom_provider_record = next(
|
||||||
|
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not custom_provider_record:
|
||||||
|
return None
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self._extract_secret_variables(
|
provider_credential_secret_variables = self._extract_secret_variables(
|
||||||
provider_entity.provider_credential_schema.credential_form_schemas
|
provider_entity.provider_credential_schema.credential_form_schemas
|
||||||
@@ -630,113 +687,98 @@ class ProviderManager:
|
|||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get custom provider record
|
# Get and decrypt provider credentials
|
||||||
custom_provider_record = None
|
provider_credentials = self._get_and_decrypt_credentials(
|
||||||
for provider_record in provider_records:
|
tenant_id=tenant_id,
|
||||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
record_id=custom_provider_record.id,
|
||||||
continue
|
encrypted_config=custom_provider_record.encrypted_config,
|
||||||
|
secret_variables=provider_credential_secret_variables,
|
||||||
|
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||||
|
is_provider=True,
|
||||||
|
)
|
||||||
|
|
||||||
custom_provider_record = provider_record
|
return CustomProviderConfiguration(
|
||||||
|
credentials=provider_credentials,
|
||||||
|
current_credential_name=custom_provider_record.credential_name,
|
||||||
|
current_credential_id=custom_provider_record.credential_id,
|
||||||
|
available_credentials=self.get_provider_available_credentials(
|
||||||
|
tenant_id, custom_provider_record.provider_name
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Get custom provider credentials
|
def _get_can_added_models(
|
||||||
custom_provider_configuration = None
|
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
|
||||||
if custom_provider_record:
|
) -> list[dict]:
|
||||||
provider_credentials_cache = ProviderCredentialsCache(
|
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
|
||||||
tenant_id=tenant_id,
|
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
|
||||||
identity_id=custom_provider_record.id,
|
|
||||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get cached provider credentials
|
# Get not added custom models credentials
|
||||||
cached_provider_credentials = provider_credentials_cache.get()
|
not_added_custom_models_credentials = [
|
||||||
|
credential
|
||||||
|
for credential in all_model_credentials
|
||||||
|
if (credential.model_name, credential.model_type) not in existing_model_set
|
||||||
|
]
|
||||||
|
|
||||||
if not cached_provider_credentials:
|
# Group credentials by model
|
||||||
try:
|
model_to_credentials = defaultdict(list)
|
||||||
# fix origin data
|
for credential in not_added_custom_models_credentials:
|
||||||
if custom_provider_record.encrypted_config is None:
|
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
|
||||||
provider_credentials = {}
|
|
||||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
|
||||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
|
||||||
else:
|
|
||||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
|
||||||
except JSONDecodeError:
|
|
||||||
provider_credentials = {}
|
|
||||||
|
|
||||||
# Get decoding rsa key and cipher for decrypting credentials
|
return [
|
||||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
{
|
||||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
"model": model_key[0],
|
||||||
|
"model_type": ModelType.value_of(model_key[1]),
|
||||||
|
"available_model_credentials": [
|
||||||
|
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||||
|
for cred in creds
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for model_key, creds in model_to_credentials.items()
|
||||||
|
]
|
||||||
|
|
||||||
for variable in provider_credential_secret_variables:
|
def _get_custom_model_configurations(
|
||||||
if variable in provider_credentials:
|
self,
|
||||||
with contextlib.suppress(ValueError):
|
tenant_id: str,
|
||||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
provider_entity: ProviderEntity,
|
||||||
provider_credentials.get(variable) or "", # type: ignore
|
provider_model_records: list[ProviderModel],
|
||||||
self.decoding_rsa_key,
|
can_added_models: list[dict],
|
||||||
self.decoding_cipher_rsa,
|
all_model_credentials: Sequence[ProviderModelCredential],
|
||||||
)
|
) -> list[CustomModelConfiguration]:
|
||||||
|
"""Get custom model configurations."""
|
||||||
# cache provider credentials
|
# Get model credential secret variables
|
||||||
provider_credentials_cache.set(credentials=provider_credentials)
|
|
||||||
else:
|
|
||||||
provider_credentials = cached_provider_credentials
|
|
||||||
|
|
||||||
custom_provider_configuration = CustomProviderConfiguration(
|
|
||||||
credentials=provider_credentials,
|
|
||||||
current_credential_name=custom_provider_record.credential_name,
|
|
||||||
current_credential_id=custom_provider_record.credential_id,
|
|
||||||
available_credentials=self.get_provider_available_credentials(
|
|
||||||
tenant_id, custom_provider_record.provider_name
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider model credential secret variables
|
|
||||||
model_credential_secret_variables = self._extract_secret_variables(
|
model_credential_secret_variables = self._extract_secret_variables(
|
||||||
provider_entity.model_credential_schema.credential_form_schemas
|
provider_entity.model_credential_schema.credential_form_schemas
|
||||||
if provider_entity.model_credential_schema
|
if provider_entity.model_credential_schema
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get custom provider model credentials
|
# Create credentials lookup for efficient access
|
||||||
|
credentials_map = defaultdict(list)
|
||||||
|
for credential in all_model_credentials:
|
||||||
|
credentials_map[(credential.model_name, credential.model_type)].append(credential)
|
||||||
|
|
||||||
custom_model_configurations = []
|
custom_model_configurations = []
|
||||||
|
|
||||||
|
# Process existing model records
|
||||||
for provider_model_record in provider_model_records:
|
for provider_model_record in provider_model_records:
|
||||||
available_model_credentials = self.get_provider_model_available_credentials(
|
# Use pre-fetched credentials instead of individual database calls
|
||||||
tenant_id,
|
available_model_credentials = [
|
||||||
provider_model_record.provider_name,
|
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||||
provider_model_record.model_name,
|
for cred in credentials_map.get(
|
||||||
provider_model_record.model_type,
|
(provider_model_record.model_name, provider_model_record.model_type), []
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get and decrypt model credentials
|
||||||
|
provider_model_credentials = self._get_and_decrypt_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
record_id=provider_model_record.id,
|
||||||
|
encrypted_config=provider_model_record.encrypted_config,
|
||||||
|
secret_variables=model_credential_secret_variables,
|
||||||
|
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||||
|
is_provider=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
|
||||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get cached provider model credentials
|
|
||||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
|
||||||
|
|
||||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
|
||||||
try:
|
|
||||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
|
||||||
except JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get decoding rsa key and cipher for decrypting credentials
|
|
||||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
|
||||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
||||||
|
|
||||||
for variable in model_credential_secret_variables:
|
|
||||||
if variable in provider_model_credentials:
|
|
||||||
with contextlib.suppress(ValueError):
|
|
||||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
||||||
provider_model_credentials.get(variable),
|
|
||||||
self.decoding_rsa_key,
|
|
||||||
self.decoding_cipher_rsa,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cache provider model credentials
|
|
||||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
|
||||||
else:
|
|
||||||
provider_model_credentials = cached_provider_model_credentials
|
|
||||||
|
|
||||||
custom_model_configurations.append(
|
custom_model_configurations.append(
|
||||||
CustomModelConfiguration(
|
CustomModelConfiguration(
|
||||||
model=provider_model_record.model_name,
|
model=provider_model_record.model_name,
|
||||||
@@ -748,7 +790,71 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
|
# Add models that can be added
|
||||||
|
for model in can_added_models:
|
||||||
|
custom_model_configurations.append(
|
||||||
|
CustomModelConfiguration(
|
||||||
|
model=model["model"],
|
||||||
|
model_type=model["model_type"],
|
||||||
|
credentials=None,
|
||||||
|
current_credential_id=None,
|
||||||
|
current_credential_name=None,
|
||||||
|
available_model_credentials=model["available_model_credentials"],
|
||||||
|
unadded_to_model_list=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom_model_configurations
|
||||||
|
|
||||||
|
def _get_and_decrypt_credentials(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
record_id: str,
|
||||||
|
encrypted_config: str | None,
|
||||||
|
secret_variables: list[str],
|
||||||
|
cache_type: ProviderCredentialsCacheType,
|
||||||
|
is_provider: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Get and decrypt credentials with caching."""
|
||||||
|
credentials_cache = ProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
identity_id=record_id,
|
||||||
|
cache_type=cache_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get from cache first
|
||||||
|
cached_credentials = credentials_cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
|
||||||
|
# Parse encrypted config
|
||||||
|
if not encrypted_config:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if is_provider and not encrypted_config.startswith("{"):
|
||||||
|
return {"openai_api_key": encrypted_config}
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials = cast(dict, json.loads(encrypted_config))
|
||||||
|
except JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Decrypt secret variables
|
||||||
|
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||||
|
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||||
|
|
||||||
|
for variable in secret_variables:
|
||||||
|
if variable in credentials:
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
|
credentials.get(variable) or "",
|
||||||
|
self.decoding_rsa_key,
|
||||||
|
self.decoding_cipher_rsa,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache the decrypted credentials
|
||||||
|
credentials_cache.set(credentials=credentials)
|
||||||
|
return credentials
|
||||||
|
|
||||||
def _to_system_configuration(
|
def _to_system_configuration(
|
||||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||||
@@ -956,18 +1062,6 @@ class ProviderManager:
|
|||||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||||
):
|
):
|
||||||
if load_balancing_model_config.name == "__delete__":
|
|
||||||
# to calculate current model whether has invalidate lb configs
|
|
||||||
load_balancing_configs.append(
|
|
||||||
ModelLoadBalancingConfiguration(
|
|
||||||
id=load_balancing_model_config.id,
|
|
||||||
name=load_balancing_model_config.name,
|
|
||||||
credentials={},
|
|
||||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not load_balancing_model_config.enabled:
|
if not load_balancing_model_config.enabled:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1033,6 +1127,7 @@ class ProviderManager:
|
|||||||
model=provider_model_setting.model_name,
|
model=provider_model_setting.model_name,
|
||||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||||
enabled=provider_model_setting.enabled,
|
enabled=provider_model_setting.enabled,
|
||||||
|
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@@ -13,6 +13,7 @@ from core.entities.provider_entities import (
|
|||||||
CustomModelConfiguration,
|
CustomModelConfiguration,
|
||||||
ProviderQuotaType,
|
ProviderQuotaType,
|
||||||
QuotaConfiguration,
|
QuotaConfiguration,
|
||||||
|
UnaddedModelConfiguration,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
@@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel):
|
|||||||
current_credential_name: Optional[str] = None
|
current_credential_name: Optional[str] = None
|
||||||
available_credentials: Optional[list[CredentialConfiguration]] = None
|
available_credentials: Optional[list[CredentialConfiguration]] = None
|
||||||
custom_models: Optional[list[CustomModelConfiguration]] = None
|
custom_models: Optional[list[CustomModelConfiguration]] = None
|
||||||
|
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
|
||||||
|
|
||||||
|
|
||||||
class SystemConfigurationResponse(BaseModel):
|
class SystemConfigurationResponse(BaseModel):
|
||||||
|
@@ -3,6 +3,8 @@ import logging
|
|||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.entities.provider_configuration import ProviderConfiguration
|
from core.entities.provider_configuration import ProviderConfiguration
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
@@ -69,7 +71,7 @@ class ModelLoadBalancingService:
|
|||||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||||
|
|
||||||
def get_load_balancing_configs(
|
def get_load_balancing_configs(
|
||||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||||
) -> tuple[bool, list[dict]]:
|
) -> tuple[bool, list[dict]]:
|
||||||
"""
|
"""
|
||||||
Get load balancing configurations.
|
Get load balancing configurations.
|
||||||
@@ -100,6 +102,11 @@ class ModelLoadBalancingService:
|
|||||||
if provider_model_setting and provider_model_setting.load_balancing_enabled:
|
if provider_model_setting and provider_model_setting.load_balancing_enabled:
|
||||||
is_load_balancing_enabled = True
|
is_load_balancing_enabled = True
|
||||||
|
|
||||||
|
if config_from == "predefined-model":
|
||||||
|
credential_source_type = "provider"
|
||||||
|
else:
|
||||||
|
credential_source_type = "custom_model"
|
||||||
|
|
||||||
# Get load balancing configurations
|
# Get load balancing configurations
|
||||||
load_balancing_configs = (
|
load_balancing_configs = (
|
||||||
db.session.query(LoadBalancingModelConfig)
|
db.session.query(LoadBalancingModelConfig)
|
||||||
@@ -108,6 +115,10 @@ class ModelLoadBalancingService:
|
|||||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||||
LoadBalancingModelConfig.model_name == model,
|
LoadBalancingModelConfig.model_name == model,
|
||||||
|
or_(
|
||||||
|
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||||
|
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
.order_by(LoadBalancingModelConfig.created_at)
|
.order_by(LoadBalancingModelConfig.created_at)
|
||||||
.all()
|
.all()
|
||||||
@@ -405,7 +416,7 @@ class ModelLoadBalancingService:
|
|||||||
self._clear_credentials_cache(tenant_id, config_id)
|
self._clear_credentials_cache(tenant_id, config_id)
|
||||||
else:
|
else:
|
||||||
# create load balancing config
|
# create load balancing config
|
||||||
if name in {"__inherit__", "__delete__"}:
|
if name == "__inherit__":
|
||||||
raise ValueError("Invalid load balancing config name")
|
raise ValueError("Invalid load balancing config name")
|
||||||
|
|
||||||
if credential_id:
|
if credential_id:
|
||||||
|
@@ -72,6 +72,7 @@ class ModelProviderService:
|
|||||||
|
|
||||||
provider_config = provider_configuration.custom_configuration.provider
|
provider_config = provider_configuration.custom_configuration.provider
|
||||||
model_config = provider_configuration.custom_configuration.models
|
model_config = provider_configuration.custom_configuration.models
|
||||||
|
can_added_models = provider_configuration.custom_configuration.can_added_models
|
||||||
|
|
||||||
provider_response = ProviderResponse(
|
provider_response = ProviderResponse(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -95,6 +96,7 @@ class ModelProviderService:
|
|||||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||||
custom_models=model_config,
|
custom_models=model_config,
|
||||||
|
can_added_models=can_added_models,
|
||||||
),
|
),
|
||||||
system_configuration=SystemConfigurationResponse(
|
system_configuration=SystemConfigurationResponse(
|
||||||
enabled=provider_configuration.system_configuration.enabled,
|
enabled=provider_configuration.system_configuration.enabled,
|
||||||
@@ -152,7 +154,7 @@ class ModelProviderService:
|
|||||||
provider_configuration.validate_provider_credentials(credentials)
|
provider_configuration.validate_provider_credentials(credentials)
|
||||||
|
|
||||||
def create_provider_credential(
|
def create_provider_credential(
|
||||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
|
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create and save new provider credentials.
|
Create and save new provider credentials.
|
||||||
@@ -172,7 +174,7 @@ class ModelProviderService:
|
|||||||
provider: str,
|
provider: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
credential_id: str,
|
credential_id: str,
|
||||||
credential_name: str,
|
credential_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
update a saved provider credential (by credential_id).
|
update a saved provider credential (by credential_id).
|
||||||
@@ -249,7 +251,7 @@ class ModelProviderService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_model_credential(
|
def create_model_credential(
|
||||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
|
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
create and save model credentials.
|
create and save model credentials.
|
||||||
@@ -278,7 +280,7 @@ class ModelProviderService:
|
|||||||
model: str,
|
model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
credential_id: str,
|
credential_id: str,
|
||||||
credential_name: str,
|
credential_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
update model credentials.
|
update model credentials.
|
||||||
|
Reference in New Issue
Block a user