diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 281783b3d..3861fb8e9 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -45,12 +46,109 @@ class ModelProviderCredentialApi(Resource): @account_initialization_required def get(self, provider: str): tenant_id = current_user.current_tenant_id + # if credential_id is not provided, return current used credential + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") + args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) + credentials = model_provider_service.get_provider_credential( + tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") + ) return {"credentials": credentials} + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_provider_credential( + tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] + ) + + return {"result": "success"}, 204 + + +class ModelProviderCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.switch_active_provider_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + credential_id=args["credential_id"], + ) + return {"result": "success"} + class ModelProviderValidateApi(Resource): @setup_required @@ -69,7 +167,7 @@ class ModelProviderValidateApi(Resource): error = "" try: - model_provider_service.provider_credentials_validate( + model_provider_service.validate_provider_credentials( tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: @@ -84,42 +182,6 @@ class ModelProviderValidateApi(Resource): return response -class ModelProviderApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") - args = parser.parse_args() - - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] - ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) - - return {"result": "success"}, 201 - - @setup_required - @login_required - @account_initialization_required - def delete(self, provider: str): - if not current_user.is_admin_or_owner: - raise Forbidden() - - model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - - return {"result": "success"}, 204 - - class ModelProviderIconApi(Resource): """ Get model provider icon @@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource( + ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers//credentials/switch" +) api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") -api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") api.add_resource( PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b8dddb91d..98702dd3b 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -9,6 +9,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import StrLen, uuid_value from libs.login import login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -98,6 +99,7 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + # To save the model's load balance configs if not current_user.is_admin_or_owner: raise Forbidden() @@ -113,22 +115,26 @@ class ModelProviderModelApi(Resource): choices=[mt.value for mt in ModelType], location="json", ) - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") args = parser.parse_args() + if args.get("config_from", "") == "custom-model": + if not args.get("credential_id"): + raise ValueError("credential_id is required when configuring a custom-model") + service = ModelProviderService() + service.switch_active_custom_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + model_load_balancing_service = ModelLoadBalancingService() - if ( - "load_balancing" in args - and args["load_balancing"] - and "enabled" in args["load_balancing"] - and args["load_balancing"]["enabled"] - ): - if "configs" not in args["load_balancing"]: - raise ValueError("invalid load balancing configs") - + if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, @@ -136,37 +142,17 @@ class ModelProviderModelApi(Resource): model=args["model"], model_type=args["model_type"], configs=args["load_balancing"]["configs"], + config_from=args.get("config_from", ""), ) - # enable load balancing - model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] - ) - else: - # disable load balancing - model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] - ) - - if args.get("config_from", "") != "predefined-model": - model_provider_service = ModelProviderService() - - try: - model_provider_service.save_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], - ) - except CredentialsValidateFailedError as ex: - logging.exception( - "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", - tenant_id, - args.get("model"), - args.get("model_type"), - ) - raise ValueError(str(ex)) + if args.get("load_balancing", {}).get("enabled"): + model_load_balancing_service.enable_model_load_balancing( + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + ) + else: + model_load_balancing_service.disable_model_load_balancing( + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + ) return {"result": "success"}, 200 @@ -192,7 +178,7 @@ class ModelProviderModelApi(Resource): args = parser.parse_args() model_provider_service = ModelProviderService() - model_provider_service.remove_model_credentials( + model_provider_service.remove_model( tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) @@ -216,11 +202,17 @@ class ModelProviderModelCredentialApi(Resource): choices=[mt.value for mt in ModelType], location="args", ) + parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") + parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() - credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] + current_credential = model_provider_service.get_model_credential( + tenant_id=tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args.get("credential_id"), ) model_load_balancing_service = ModelLoadBalancingService() @@ -228,10 +220,173 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return { - "credentials": credentials, - "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, - } + if args.get("config_from", "") == "predefined-model": + available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( + tenant_id=tenant_id, provider_name=provider + ) + else: + model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() + available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( + tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] + ) + + return jsonable_encoder( + { + "credentials": current_credential.get("credentials") if current_credential else {}, + "current_credential_id": current_credential.get("current_credential_id") + if current_credential + else None, + "current_credential_name": current_credential.get("current_credential_name") + if current_credential + else None, + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, + "available_credentials": available_credentials, + } + ) + + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + tenant_id = current_user.current_tenant_id + model_provider_service = ModelProviderService() + + try: + model_provider_service.create_model_credential( + tenant_id=tenant_id, + provider=provider, + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + logging.exception( + "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", + tenant_id, + args.get("model"), + args.get("model_type"), + ) + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + + @setup_required + @login_required + @account_initialization_required + def put(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + + try: + model_provider_service.update_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credentials=args["credentials"], + credential_id=args["credential_id"], + credential_name=args["name"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"} + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.remove_model_credential( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + + return {"result": "success"}, 204 + + +class ModelProviderModelCredentialSwitchApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not current_user.is_admin_or_owner: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + service = ModelProviderService() + service.add_model_credential_to_model_list( + tenant_id=current_user.current_tenant_id, + provider=provider, + model_type=args["model_type"], + model=args["model"], + credential_id=args["credential_id"], + ) + return {"result": "success"} class ModelProviderModelEnableApi(Resource): @@ -314,7 +469,7 @@ class ModelProviderModelValidateApi(Resource): error = "" try: - model_provider_service.model_credentials_validate( + model_provider_service.validate_model_credentials( tenant_id=tenant_id, provider=provider, model=args["model"], @@ -379,6 +534,10 @@ api.add_resource( api.add_resource( ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" ) +api.add_resource( + ModelProviderModelCredentialSwitchApi, + "/workspaces/current/model-providers//models/credentials/switch", +) api.add_resource( ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" ) diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index e1c021a44..ac64a8e3a 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -19,6 +19,7 @@ class ModelStatus(Enum): QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" DISABLED = "disabled" + CREDENTIAL_REMOVED = "credential-removed" class SimpleModelProviderEntity(BaseModel): @@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + has_invalid_load_balancing_configs: bool = False def raise_for_status(self) -> None: """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 646e0e21e..ca3c36b87 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,6 +6,8 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import func, select +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity @@ -32,7 +34,9 @@ from libs.datetime_utils import naive_utc_now from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantPreferredModelProvider, @@ -45,7 +49,16 @@ original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): """ - Model class for provider configuration. + Provider configuration entity for managing model provider settings. + + This class handles: + - Provider credentials CRUD and switch + - Custom Model credentials CRUD and switch + - System vs custom provider switching + - Load balancing configurations + - Model enablement/disablement + + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ tenant_id: str @@ -155,33 +168,17 @@ class ProviderConfiguration(BaseModel): Check custom configuration available. :return: """ - return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - - def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: - """ - Get custom credentials. - - :param obfuscated: obfuscated secret data in credentials - :return: - """ - if self.custom_configuration.provider is None: - return None - - credentials = self.custom_configuration.provider.credentials - if not obfuscated: - return credentials - - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [], + has_provider_credentials = ( + self.custom_configuration.provider is not None + and len(self.custom_configuration.provider.available_credentials) > 0 ) - def _get_custom_provider_credentials(self) -> Provider | None: + has_model_configurations = len(self.custom_configuration.models) > 0 + return has_provider_credentials or has_model_configurations + + def _get_provider_record(self, session: Session) -> Provider | None: """ - Get custom provider credentials. + Get custom provider record. """ # get provider model_provider_id = ModelProviderID(self.provider.provider) @@ -189,156 +186,442 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_record = ( - db.session.query(Provider) - .where( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(provider_names), - ) - .first() + stmt = select(Provider).where( + Provider.tenant_id == self.tenant_id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(provider_names), ) - return provider_record + return session.execute(stmt).scalar_one_or_none() - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: + def _get_specific_provider_credential(self, credential_id: str) -> dict | None: """ - Validate custom credentials. - :param credentials: provider credentials + Get a specific provider credential by ID. + :param credential_id: Credential ID :return: """ - provider_record = self._get_custom_provider_credentials() - - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( + # Extract secret variables from provider credential schema + credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) - if provider_record: - try: - # fix origin data - if provider_record.encrypted_config: - if not provider_record.encrypted_config.startswith("{"): - original_credentials = {"openai_api_key": provider_record.encrypted_config} - else: - original_credentials = json.loads(provider_record.encrypted_config) - else: - original_credentials = {} - except JSONDecodeError: - original_credentials = {} + with Session(db.engine) as session: + # Prefer the actual provider record name if exists (to handle aliased provider names) + provider_record = self._get_provider_record(session) + provider_name = provider_record.provider_name if provider_record else self.provider.provider - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_record, credentials - - def add_or_update_custom_credentials(self, credentials: dict) -> None: - """ - Add or update custom provider credentials. - :param credentials: - :return: - """ - # validate custom provider config - provider_record, credentials = self.custom_credentials_validate(credentials) - - # save provider - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_record: - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - provider_record.updated_at = naive_utc_now() - db.session.commit() - else: - provider_record = Provider() - provider_record.tenant_id = self.tenant_id - provider_record.provider_name = self.provider.provider - provider_record.provider_type = ProviderType.CUSTOM.value - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - - db.session.add(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - provider_model_credentials_cache.delete() - - self.switch_preferred_provider_type(ProviderType.CUSTOM) - - def delete_custom_credentials(self) -> None: - """ - Delete custom provider credentials. - :return: - """ - # get provider - provider_record = self._get_custom_provider_credentials() - - # delete provider - if provider_record: - self.switch_preferred_provider_type(ProviderType.SYSTEM) - - db.session.delete(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER, + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == provider_name, ) - provider_model_credentials_cache.delete() + credential = session.execute(stmt).scalar_one_or_none() - def get_custom_model_credentials( - self, model_type: ModelType, model: str, obfuscated: bool = False - ) -> Optional[dict]: + if not credential or not credential.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def _check_provider_credential_name_exists( + self, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: """ - Get custom model credentials. + not allowed same name when create or update a credential + """ + stmt = select(ProviderCredential.id).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None - :param model_type: model type - :param model: model name - :param obfuscated: obfuscated secret data in credentials + def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + """ + Get provider credentials. + + :param credential_id: if provided, return the specified credential :return: """ - if not self.custom_configuration.models: - return None - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - credentials = model_configuration.credentials - if not obfuscated: - return credentials + if credential_id: + return self._get_specific_provider_credential(credential_id) - # Obfuscate credentials - return self.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [], + # Default behavior: return current active provider credentials + credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {} + + return self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [], + ) + + def validate_provider_credentials( + self, credentials: dict, credential_id: str = "", session: Session | None = None + ) -> dict: + """ + Validate custom credentials. + :param credentials: provider credentials + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate + :param session: optional database session + :return: + """ + + def _validate(s: Session) -> dict: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) + + if credential_id: + try: + stmt = select(ProviderCredential).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.id == credential_id, + ) + credential_record = s.execute(stmt).scalar_one_or_none() + # fix origin data + if credential_record and credential_record.encrypted_config: + if not credential_record.encrypted_config.startswith("{"): + original_credentials = {"openai_api_key": credential_record.encrypted_config} + else: + original_credentials = json.loads(credential_record.encrypted_config) + else: + original_credentials = {} + except JSONDecodeError: + original_credentials = {} + + # encrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_provider_credential(self, credentials: dict, credential_name: str) -> None: + """ + Add custom provider credentials. + :param credentials: provider credentials + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials(credentials=credentials, session=session) + provider_record = self._get_provider_record(session) + try: + new_record = ProviderCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + encrypted_config=json.dumps(credentials), + credential_name=credential_name, ) + session.add(new_record) + session.flush() - return None + if not provider_record: + # If provider record does not exist, create it + provider_record = Provider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + provider_type=ProviderType.CUSTOM.value, + is_valid=True, + credential_id=new_record.id, + ) + session.add(provider_record) - def _get_custom_model_credentials( + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def update_provider_credential( + self, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update a saved provider credential (by credential_id). + + :param credentials: provider credentials + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + with Session(db.engine) as session: + if self._check_provider_credential_name_exists( + credential_name=credential_name, session=session, exclude_id=credential_id + ): + raise ValueError(f"Credential with name '{credential_name}' already exists.") + + credentials = self.validate_provider_credentials( + credentials=credentials, credential_id=credential_id, session=session + ) + provider_record = self._get_provider_record(session) + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = naive_utc_now() + + session.commit() + + if provider_record and provider_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="provider", + session=session, + ) + except Exception: + session.rollback() + raise + + def _update_load_balancing_configs_with_credential( + self, + credential_id: str, + credential_record: ProviderCredential | ProviderModelCredential, + credential_source: str, + session: Session, + ) -> None: + """ + Update load balancing configurations that reference the given credential_id. + + :param credential_id: credential id + :param credential_record: the encrypted_config and credential_name + :param credential_source: the credential comes from the provider_credential(`provider`) + or the provider_model_credential(`custom_model`) + :param session: the database session + :return: + """ + # Find all load balancing configs that use this credential_id + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == credential_source, + ) + load_balancing_configs = session.execute(stmt).scalars().all() + + if not load_balancing_configs: + return + + # Update each load balancing config with the new credentials + for lb_config in load_balancing_configs: + # Update the encrypted_config with the new credentials + lb_config.encrypted_config = credential_record.encrypted_config + lb_config.name = credential_record.credential_name + lb_config.updated_at = naive_utc_now() + + # Clear cache for this load balancing config + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + session.commit() + + def delete_provider_credential(self, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + + # Get the credential record to update + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # Check if this credential is used in load balancing configs + lb_stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "provider", + ) + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = naive_utc_now() + session.add(lb_config) + + # Check if this is the currently active credential + provider_record = self._get_provider_record(session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the provider record + count_stmt = select(func.count(ProviderCredential.id)).where( + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + available_credentials_count = session.execute(count_stmt).scalar() or 0 + session.delete(credential_record) + + if provider_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the provider record + session.delete(provider_record) + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) + elif provider_record and provider_record.credential_id == credential_id: + provider_record.credential_id = None + provider_record.updated_at = naive_utc_now() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) + + session.commit() + except Exception: + session.rollback() + raise + + def switch_active_provider_credential(self, credential_id: str) -> None: + """ + Switch active provider credential (copy the selected one into current active snapshot). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ProviderCredential.provider_name == self.provider.provider, + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_record = self._get_provider_record(session) + if not provider_record: + raise ValueError("Provider record not found.") + + try: + provider_record.credential_id = credential_record.id + provider_record.updated_at = naive_utc_now() + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + except Exception: + session.rollback() + raise + + def _get_custom_model_record( self, model_type: ModelType, model: str, + session: Session, ) -> ProviderModel | None: """ Get custom model credentials. @@ -349,128 +632,495 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - provider_model_record = ( - db.session.query(ProviderModel) - .where( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name.in_(provider_names), - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(ProviderModel).where( + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name.in_(provider_names), + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), ) - return provider_model_record + return session.execute(stmt).scalar_one_or_none() - def custom_model_credentials_validate( - self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel | None, dict]: + def _get_specific_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str + ) -> dict | None: """ - Validate custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials + Get a specific provider credential by ID. + :param credential_id: Credential ID :return: """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( + model_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) - if provider_model_record: - try: - original_credentials = ( - json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} - ) - except JSONDecodeError: - original_credentials = {} - - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - model_provider_factory = ModelProviderFactory(self.tenant_id) - credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_model_record, credentials - - def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: - """ - Add or update custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: - """ - # validate custom model config - provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) - - # save provider model - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_model_record: - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - provider_model_record.updated_at = naive_utc_now() - db.session.commit() - else: - provider_model_record = ProviderModel() - provider_model_record.tenant_id = self.tenant_id - provider_model_record.provider_name = self.provider.provider - provider_model_record.model_name = model - provider_model_record.model_type = model_type.to_origin_model_type() - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - db.session.add(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, - ) - - provider_model_credentials_cache.delete() - - def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: - """ - Delete custom model credentials. - :param model_type: model type - :param model: model name - :return: - """ - # get provider model - provider_model_record = self._get_custom_model_credentials(model_type, model) - - # delete provider model - if provider_model_record: - db.session.delete(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL, + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) - provider_model_credentials_cache.delete() + credential_record = session.execute(stmt).scalar_one_or_none() - def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: + if not credential_record or not credential_record.encrypted_config: + raise ValueError(f"Credential with id {credential_id} not found.") + + try: + credentials = json.loads(credential_record.encrypted_config) + except JSONDecodeError: + credentials = {} + + # Decrypt secret variables + for key in model_credential_secret_variables: + if key in credentials and credentials[key] is not None: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) + except Exception: + pass + + current_credential_id = credential_record.id + current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } + + def _check_custom_model_credential_name_exists( + self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None + ) -> bool: + """ + not allowed same name when create or update a credential + """ + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.credential_name == credential_name, + ) + if exclude_id: + stmt = stmt.where(ProviderModelCredential.id != exclude_id) + return session.execute(stmt).scalar_one_or_none() is not None + + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Get custom model credentials. + + :param model_type: model type + :param model: model name + :return: + """ + # If credential_id is provided, return the specific credential + if credential_id: + return self._get_specific_custom_model_credential( + model_type=model_type, model=model, credential_id=credential_id + ) + + for model_configuration in self.custom_configuration.models: + if ( + model_configuration.model_type == model_type + and model_configuration.model == model + and model_configuration.credentials + ): + current_credential_id = model_configuration.current_credential_id + current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( + credentials=model_configuration.credentials, + credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [], + ) + return { + "current_credential_id": current_credential_id, + "current_credential_name": current_credential_name, + "credentials": credentials, + } + return None + + def validate_custom_model_credentials( + self, + model_type: ModelType, + model: str, + credentials: dict, + credential_id: str = "", + session: Session | None = None, + ) -> dict: + """ + Validate custom model credentials. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate + :return: + """ + + def _validate(s: Session) -> dict: + # Get provider credential secret variables + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) + + if credential_id: + try: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = s.execute(stmt).scalar_one_or_none() + original_credentials = ( + json.loads(credential_record.encrypted_config) + if credential_record and credential_record.encrypted_config + else {} + ) + except JSONDecodeError: + original_credentials = {} + + # decrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) + + model_provider_factory = ModelProviderFactory(self.tenant_id) + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) + + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables: + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials + + if session: + return _validate(session) + else: + with Session(db.engine) as new_session: + return _validate(new_session) + + def create_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str + ) -> None: + """ + Create a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, model_type=model_type, credential_name=credential_name, session=session + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials, session=session + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + try: + credential = ProviderModelCredential( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + encrypted_config=json.dumps(credentials), + credential_name=credential_name, + ) + session.add(credential) + session.flush() + + # save provider model + if not provider_model_record: + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential.id, + is_valid=True, + ) + session.add(provider_model_record) + + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + except Exception: + session.rollback() + raise + + def update_custom_model_credential( + self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str + ) -> None: + """ + Update a custom model credential. + + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_name: credential name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + if self._check_custom_model_credential_name_exists( + model=model, + model_type=model_type, + credential_name=credential_name, + session=session, + exclude_id=credential_id, + ): + raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") + # validate custom model config + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + session=session, + ) + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + try: + # Update credential + credential_record.encrypted_config = json.dumps(credentials) + credential_record.credential_name = credential_name + credential_record.updated_at = naive_utc_now() + session.commit() + + if provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + provider_model_credentials_cache.delete() + + self._update_load_balancing_configs_with_credential( + credential_id=credential_id, + credential_record=credential_record, + credential_source="custom_model", + session=session, + ) + except Exception: + session.rollback() + raise + + def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + Delete a saved provider credential (by credential_id). + + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + lb_stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.credential_id == credential_id, + LoadBalancingModelConfig.credential_source_type == "custom_model", + ) + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() + + try: + for lb_config in lb_configs_using_credential: + lb_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=lb_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, + ) + lb_credentials_cache.delete() + lb_config.credential_id = None + lb_config.encrypted_config = None + lb_config.enabled = False + lb_config.name = "__delete__" + lb_config.updated_at = naive_utc_now() + session.add(lb_config) + + # Check if this is the currently active credential + provider_model_record = self._get_custom_model_record(model_type, model, session=session) + + # Check available credentials count BEFORE deleting + # if this is the last credential, we need to delete the custom model record + count_stmt = select(func.count(ProviderModelCredential.id)).where( + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + available_credentials_count = session.execute(count_stmt).scalar() or 0 + session.delete(credential_record) + + if provider_model_record and available_credentials_count <= 1: + # If all credentials are deleted, delete the custom model record + session.delete(provider_model_record) + elif provider_model_record and provider_model_record.credential_id == credential_id: + provider_model_record.credential_id = None + provider_model_record.updated_at = naive_utc_now() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + session.commit() + + except Exception: + session.rollback() + raise + + def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + if model list exist this custom model, switch the custom model credential. + if model list not exist this custom model, use the credential to add a new custom model record. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + # validate custom model config + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + if not provider_model_record: + # create provider model record + provider_model_record = ProviderModel( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_name=model, + model_type=model_type.to_origin_model_type(), + credential_id=credential_id, + ) + else: + if provider_model_record.credential_id == credential_record.id: + raise ValueError("Can't add same credential") + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = naive_utc_now() + session.add(provider_model_record) + session.commit() + + def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: + """ + switch the custom model credential. + + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + with Session(db.engine) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.id == credential_id, + ProviderModelCredential.tenant_id == self.tenant_id, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ) + credential_record = session.execute(stmt).scalar_one_or_none() + if not credential_record: + raise ValueError("Credential record not found.") + + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + if not provider_model_record: + raise ValueError("The custom model record not found.") + + provider_model_record.credential_id = credential_record.id + provider_model_record.updated_at = naive_utc_now() + session.add(provider_model_record) + session.commit() + + def delete_custom_model(self, model_type: ModelType, model: str) -> None: + """ + Delete custom model. + :param model_type: model type + :param model: model name + :return: + """ + with Session(db.engine) as session: + # get provider model + provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) + + # delete provider model + if provider_model_record: + session.delete(provider_model_record) + session.commit() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL, + ) + + provider_model_credentials_cache.delete() + + def _get_provider_model_setting( + self, model_type: ModelType, model: str, session: Session + ) -> ProviderModelSetting | None: """ Get provider model setting. """ @@ -479,16 +1129,13 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - return ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() + stmt = select(ProviderModelSetting).where( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, ) + return session.execute(stmt).scalars().first() def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -497,21 +1144,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = True - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = True + model_setting.updated_at = naive_utc_now() + + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -522,21 +1171,22 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = self._get_provider_model_setting(model_type, model) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.enabled = False - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.enabled = False + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -547,27 +1197,8 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return self._get_provider_model_setting(model_type, model) - - def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: - """ - Get load balancing config. - """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - - return ( - db.session.query(LoadBalancingModelConfig) - .where( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name.in_(provider_names), - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + return self._get_provider_model_setting(model_type=model_type, model=model, session=session) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -581,35 +1212,32 @@ class ProviderConfiguration(BaseModel): if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) - load_balancing_config_count = ( - db.session.query(LoadBalancingModelConfig) - .where( + with Session(db.engine) as session: + stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - .count() - ) + load_balancing_config_count = session.execute(stmt).scalar() or 0 + if load_balancing_config_count <= 1: + raise ValueError("Model load balancing configuration must be more than 1.") - if load_balancing_config_count <= 1: - raise ValueError("Model load balancing configuration must be more than 1.") + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - model_setting = self._get_provider_model_setting(model_type, model) - - if model_setting: - model_setting.load_balancing_enabled = True - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = True - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = True + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=True, + ) + session.add(model_setting) + session.commit() return model_setting @@ -620,35 +1248,23 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - model_setting = ( - db.session.query(ProviderModelSetting) - .where( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + with Session(db.engine) as session: + model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) - if model_setting: - model_setting.load_balancing_enabled = False - model_setting.updated_at = naive_utc_now() - db.session.commit() - else: - model_setting = ProviderModelSetting() - model_setting.tenant_id = self.tenant_id - model_setting.provider_name = self.provider.provider - model_setting.model_type = model_type.to_origin_model_type() - model_setting.model_name = model - model_setting.load_balancing_enabled = False - db.session.add(model_setting) - db.session.commit() + if model_setting: + model_setting.load_balancing_enabled = False + model_setting.updated_at = naive_utc_now() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=False, + ) + session.add(model_setting) + session.commit() return model_setting @@ -664,7 +1280,7 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) - def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None: + def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: """ Get model schema """ @@ -673,7 +1289,7 @@ class ProviderConfiguration(BaseModel): provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) - def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: + def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: """ Switch preferred provider type. :param provider_type: @@ -685,31 +1301,35 @@ class ProviderConfiguration(BaseModel): if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return - # get preferred provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) + def _switch(s: Session) -> None: + # get preferred provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) - preferred_model_provider = ( - db.session.query(TenantPreferredModelProvider) - .where( + stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(provider_names), ) - .first() - ) + preferred_model_provider = s.execute(stmt).scalars().first() - if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + if preferred_model_provider: + preferred_model_provider.preferred_provider_type = provider_type.value + else: + preferred_model_provider = TenantPreferredModelProvider( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + preferred_provider_type=provider_type.value, + ) + s.add(preferred_model_provider) + s.commit() + + if session: + return _switch(session) else: - preferred_model_provider = TenantPreferredModelProvider() - preferred_model_provider.tenant_id = self.tenant_id - preferred_model_provider.provider_name = self.provider.provider - preferred_model_provider.preferred_provider_type = provider_type.value - db.session.add(preferred_model_provider) - - db.session.commit() + with Session(db.engine) as session: + return _switch(session) def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ @@ -973,14 +1593,24 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False + has_invalid_load_balancing_configs = False if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + provider_model_lb_configs = [ + config + for config in model_setting.load_balancing_configs + if config.credential_source_type != "custom_model" + ] + + if len(provider_model_lb_configs) > 1: load_balancing_enabled = True + if any(config.name == "__delete__" for config in provider_model_lb_configs): + has_invalid_load_balancing_configs = True + provider_models.append( ModelWithProviderEntity( model=m.model, @@ -993,6 +1623,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) @@ -1017,6 +1648,7 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE load_balancing_enabled = False + has_invalid_load_balancing_configs = False if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] @@ -1025,9 +1657,21 @@ class ProviderConfiguration(BaseModel): if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + custom_model_lb_configs = [ + config + for config in model_setting.load_balancing_configs + if config.credential_source_type != "provider" + ] + + if len(custom_model_lb_configs) > 1: load_balancing_enabled = True + if any(config.name == "__delete__" for config in custom_model_lb_configs): + has_invalid_load_balancing_configs = True + + if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: + status = ModelStatus.CREDENTIAL_REMOVED + provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, @@ -1040,6 +1684,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a5a6e62bd..1b87bffe5 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel): restrict_models: list[RestrictModel] = [] +class CredentialConfiguration(BaseModel): + """ + Model class for credential configuration. + """ + + credential_id: str + credential_name: str + + class SystemConfiguration(BaseModel): """ Model class for provider system configuration. @@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel): """ credentials: dict + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: list[CredentialConfiguration] = [] class CustomModelConfiguration(BaseModel): @@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict + credentials: dict | None + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_model_credentials: list[CredentialConfiguration] = [] # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel): id: str name: str credentials: dict + credential_source_type: str | None = None class ModelSettings(BaseModel): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index f8590b38f..24cf69a50 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -201,7 +201,7 @@ class ModelProviderFactory: return filtered_credentials def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict + self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None ) -> AIModelEntity | None: """ Get model schema diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 39fec951b..28a4ce077 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -12,6 +12,7 @@ from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( + CredentialConfiguration, CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, @@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client from models.provider import ( LoadBalancingModelConfig, Provider, + ProviderCredential, ProviderModel, + ProviderModelCredential, ProviderModelSetting, ProviderType, TenantDefaultModel, @@ -488,6 +491,61 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict + @staticmethod + def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: + """ + Get provider all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderCredential) + .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) + .order_by(ProviderCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + + @staticmethod + def get_provider_model_available_credentials( + tenant_id: str, provider_name: str, model_name: str, model_type: str + ) -> list[CredentialConfiguration]: + """ + Get provider custom model all credentials. + + :param tenant_id: workspace id + :param provider_name: provider name + :param model_name: model name + :param model_type: model type + :return: + """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(ProviderModelCredential) + .where( + ProviderModelCredential.tenant_id == tenant_id, + ProviderModelCredential.provider_name == provider_name, + ProviderModelCredential.model_name == model_name, + ProviderModelCredential.model_type == model_type, + ) + .order_by(ProviderModelCredential.created_at.desc()) + ) + + available_credentials = session.scalars(stmt).all() + + return [ + CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) + for credential in available_credentials + ] + @staticmethod def _init_trial_provider_records( tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] @@ -590,9 +648,6 @@ class ProviderManager: if provider_record.provider_type == ProviderType.SYSTEM.value: continue - if not provider_record.encrypted_config: - continue - custom_provider_record = provider_record # Get custom provider credentials @@ -611,8 +666,8 @@ class ProviderManager: try: # fix origin data if custom_provider_record.encrypted_config is None: - raise ValueError("No credentials found") - if not custom_provider_record.encrypted_config.startswith("{"): + provider_credentials = {} + elif not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -637,7 +692,14 @@ class ProviderManager: else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) + custom_provider_configuration = CustomProviderConfiguration( + credentials=provider_credentials, + current_credential_name=custom_provider_record.credential_name, + current_credential_id=custom_provider_record.credential_id, + available_credentials=self.get_provider_available_credentials( + tenant_id, custom_provider_record.provider_name + ), + ) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( @@ -649,8 +711,12 @@ class ProviderManager: # Get custom provider model credentials custom_model_configurations = [] for provider_model_record in provider_model_records: - if not provider_model_record.encrypted_config: - continue + available_model_credentials = self.get_provider_model_available_credentials( + tenant_id, + provider_model_record.provider_name, + provider_model_record.model_name, + provider_model_record.model_type, + ) provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL @@ -659,7 +725,7 @@ class ProviderManager: # Get cached provider model credentials cached_provider_model_credentials = provider_model_credentials_cache.get() - if not cached_provider_model_credentials: + if not cached_provider_model_credentials and provider_model_record.encrypted_config: try: provider_model_credentials = json.loads(provider_model_record.encrypted_config) except JSONDecodeError: @@ -688,6 +754,9 @@ class ProviderManager: model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), credentials=provider_model_credentials, + current_credential_id=provider_model_record.credential_id, + current_credential_name=provider_model_record.credential_name, + available_model_credentials=available_model_credentials, ) ) @@ -899,6 +968,18 @@ class ProviderManager: load_balancing_model_config.model_name == provider_model_setting.model_name and load_balancing_model_config.model_type == provider_model_setting.model_type ): + if load_balancing_model_config.name == "__delete__": + # to calculate current model whether has invalidate lb configs + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + credential_source_type=load_balancing_model_config.credential_source_type, + ) + ) + continue + if not load_balancing_model_config.enabled: continue @@ -955,6 +1036,7 @@ class ProviderManager: id=load_balancing_model_config.id, name=load_balancing_model_config.name, credentials=provider_model_credentials, + credential_source_type=load_balancing_model_config.credential_source_type, ) ) diff --git a/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py new file mode 100644 index 000000000..87b42346d --- /dev/null +++ b/api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py @@ -0,0 +1,177 @@ +"""Add provider multi credential support + +Revision ID: e8446f481c1e +Revises: 8bcc02c9bd07 +Create Date: 2025-08-09 15:53:54.341341 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column +import uuid + +# revision identifiers, used by Alembic. +revision = 'e8446f481c1e' +down_revision = 'fa8b0fa6f407' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_credentials table + op.create_table('provider_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_credential_pkey') + ) + + # Create index for provider_credentials + with op.batch_alter_table('provider_credentials', schema=None) as batch_op: + batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # Add credential_id to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + # Add credential_id to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + migrate_existing_providers_data() + + # Remove encrypted_config column from providers table after migration + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_providers_data(): + """migrate providers table data to provider_credentials""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + # Get database connection + conn = op.get_bind() + + # Query all existing providers data + existing_providers = conn.execute( + sa.select(providers_table.c.id, providers_table.c.tenant_id, + providers_table.c.provider_name, providers_table.c.encrypted_config, + providers_table.c.created_at, providers_table.c.updated_at) + .where(providers_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider and insert into provider_credentials + for provider in existing_providers: + credential_id = str(uuid.uuid4()) + if not provider.encrypted_config or provider.encrypted_config.strip() == '': + continue + + # Insert into provider_credentials table + conn.execute( + provider_credential_table.insert().values( + id=credential_id, + tenant_id=provider.tenant_id, + provider_name=provider.provider_name, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider.encrypted_config, + created_at=provider.created_at, + updated_at=provider.updated_at + ) + ) + + # Update original providers table, set credential_id + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values( + credential_id=credential_id, + ) + ) + +def downgrade(): + # Re-add encrypted_config column to providers table + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_credentials to providers + migrate_data_back_to_providers() + + # Remove credential_id columns + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Drop provider_credentials table + op.drop_table('provider_credentials') + + +def migrate_data_back_to_providers(): + """Migrate data back from provider_credentials to providers table for downgrade""" + + # Define table structure for data manipulation + providers_table = table('providers', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_credential_table = table('provider_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query providers that have credential_id + providers_with_credentials = conn.execute( + sa.select(providers_table.c.id, providers_table.c.credential_id) + .where(providers_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider, get the credential data and update providers table + for provider in providers_with_credentials: + credential = conn.execute( + sa.select(provider_credential_table.c.encrypted_config) + .where(provider_credential_table.c.id == provider.credential_id) + ).fetchone() + + if credential: + # Update providers table with encrypted_config from credential + conn.execute( + providers_table.update() + .where(providers_table.c.id == provider.id) + .values(encrypted_config=credential.encrypted_config) + ) \ No newline at end of file diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py new file mode 100644 index 000000000..bec1a4540 --- /dev/null +++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py @@ -0,0 +1,186 @@ +"""Add provider model multi credential support + +Revision ID: 0e154742a5fa +Revises: e8446f481c1e +Create Date: 2025-08-13 16:05:42.657730 + +""" +import uuid + +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.sql import table, column + + +# revision identifiers, used by Alembic. +revision = '0e154742a5fa' +down_revision = 'e8446f481c1e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Create provider_model_credentials table + op.create_table('provider_model_credentials', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('credential_name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey') + ) + + # Create index for provider_model_credentials + with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op: + batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False) + + # Add credential_id to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) + + + # Add credential_source_type to load_balancing_model_configs table + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) + + # Migrate existing provider_models data + migrate_existing_provider_models_data() + + # Remove encrypted_config column from provider_models table after migration + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('encrypted_config') + + +def migrate_existing_provider_models_data(): + """migrate provider_models table data to provider_model_credentials""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', sa.Text()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) + + + # Get database connection + conn = op.get_bind() + + # Query all existing provider_models data with encrypted_config + existing_provider_models = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id, + provider_models_table.c.provider_name, provider_models_table.c.model_name, + provider_models_table.c.model_type, provider_models_table.c.encrypted_config, + provider_models_table.c.created_at, provider_models_table.c.updated_at) + .where(provider_models_table.c.encrypted_config.isnot(None)) + ).fetchall() + + # Iterate through each provider_model and insert into provider_model_credentials + for provider_model in existing_provider_models: + if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '': + continue + + credential_id = str(uuid.uuid4()) + + # Insert into provider_model_credentials table + conn.execute( + provider_model_credentials_table.insert().values( + id=credential_id, + tenant_id=provider_model.tenant_id, + provider_name=provider_model.provider_name, + model_name=provider_model.model_name, + model_type=provider_model.model_type, + credential_name='API_KEY1', # Use a default name + encrypted_config=provider_model.encrypted_config, + created_at=provider_model.created_at, + updated_at=provider_model.updated_at + ) + ) + + # Update original provider_models table, set credential_id + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(credential_id=credential_id) + ) + + +def downgrade(): + # Re-add encrypted_config column to provider_models table + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) + + # Migrate data back from provider_model_credentials to provider_models + migrate_data_back_to_provider_models() + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_column('credential_id') + + # Remove credential_source_type column from load_balancing_model_configs + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_column('credential_source_type') + + # Drop provider_model_credentials table + op.drop_table('provider_model_credentials') + + +def migrate_data_back_to_provider_models(): + """Migrate data back from provider_model_credentials to provider_models table for downgrade""" + + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + column('credential_id', models.types.StringUUID()), + ) + + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('encrypted_config', sa.Text()), + ) + + # Get database connection + conn = op.get_bind() + + # Query provider_models that have credential_id + provider_models_with_credentials = conn.execute( + sa.select(provider_models_table.c.id, provider_models_table.c.credential_id) + .where(provider_models_table.c.credential_id.isnot(None)) + ).fetchall() + + # For each provider_model, get the credential data and update provider_models table + for provider_model in provider_models_with_credentials: + credential = conn.execute( + sa.select(provider_model_credentials_table.c.encrypted_config) + .where(provider_model_credentials_table.c.id == provider_model.credential_id) + ).fetchone() + + if credential: + # Update provider_models table with encrypted_config from credential + conn.execute( + provider_models_table.update() + .where(provider_models_table.c.id == provider_model.id) + .values(encrypted_config=credential.encrypted_config) + ) diff --git a/api/models/provider.py b/api/models/provider.py index 4ea2c59fd..e75b26fd3 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +from functools import cached_property from typing import Optional import sqlalchemy as sa @@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base +from .engine import db from .types import StringUUID @@ -60,9 +62,9 @@ class Provider(Base): provider_type: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") @@ -79,6 +81,21 @@ class Provider(Base): f" provider_type='{self.provider_type}')>" ) + @cached_property + def credential(self): + if self.credential_id: + return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + @property def token_is_set(self): """ @@ -116,11 +133,30 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + @cached_property + def credential(self): + if self.credential_id: + return ( + db.session.query(ProviderModelCredential) + .where(ProviderModelCredential.id == self.credential_id) + .first() + ) + + @property + def credential_name(self): + credential = self.credential + return credential.credential_name if credential else None + + @property + def encrypted_config(self): + credential = self.credential + return credential.encrypted_config if credential else None + class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" @@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base): model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderCredential(Base): + """ + Provider credential - stores multiple named credentials for each provider + """ + + __tablename__ = "provider_credentials" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"), + sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + + +class ProviderModelCredential(Base): + """ + Provider model credential - stores multiple named credentials for each provider model + """ + + __tablename__ = "provider_model_credentials" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"), + sa.Index( + "provider_model_credential_tenant_provider_model_idx", + "tenant_id", + "provider_name", + "model_name", + "model_type", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + credential_name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index bc385b2e2..056decda2 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -8,7 +8,12 @@ from core.entities.model_entities import ( ModelWithProviderEntity, ProviderModelWithStatusEntity, ) -from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomModelConfiguration, + ProviderQuotaType, + QuotaConfiguration, +) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( @@ -36,6 +41,10 @@ class CustomConfigurationResponse(BaseModel): """ status: CustomConfigurationStatus + current_credential_id: Optional[str] = None + current_credential_name: Optional[str] = None + available_credentials: Optional[list[CredentialConfiguration]] = None + custom_models: Optional[list[CustomModelConfiguration]] = None class SystemConfigurationResponse(BaseModel): diff --git a/api/services/errors/app_model_config.py b/api/services/errors/app_model_config.py index c0669ed23..bb5eb62b7 100644 --- a/api/services/errors/app_model_config.py +++ b/api/services/errors/app_model_config.py @@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError class AppModelConfigBrokenError(BaseServiceError): pass + + +class ProviderNotFoundError(BaseServiceError): + pass diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index f8dd70c79..2145b4cdd 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -17,7 +17,7 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi from core.provider_manager import ProviderManager from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.provider import LoadBalancingModelConfig +from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -185,6 +185,7 @@ class ModelLoadBalancingService: "id": load_balancing_config.id, "name": load_balancing_config.name, "credentials": credentials, + "credential_id": load_balancing_config.credential_id, "enabled": load_balancing_config.enabled, "in_cooldown": in_cooldown, "ttl": ttl, @@ -280,7 +281,7 @@ class ModelLoadBalancingService: return inherit_config def update_load_balancing_configs( - self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str ) -> None: """ Update load balancing configurations. @@ -289,6 +290,7 @@ class ModelLoadBalancingService: :param model: model name :param model_type: model type :param configs: load balancing configs + :param config_from: predefined-model or custom-model :return: """ # Get all provider configurations of the current workspace @@ -327,8 +329,37 @@ class ModelLoadBalancingService: config_id = config.get("id") name = config.get("name") credentials = config.get("credentials") + credential_id = config.get("credential_id") enabled = config.get("enabled") + if credential_id: + credential_record: ProviderCredential | ProviderModelCredential | None = None + if config_from == "predefined-model": + credential_record = ( + db.session.query(ProviderCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + ) + .first() + ) + else: + credential_record = ( + db.session.query(ProviderModelCredential) + .filter_by( + id=credential_id, + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_name=model, + model_type=model_type_enum.to_origin_model_type(), + ) + .first() + ) + if not credential_record: + raise ValueError(f"Provider credential with id {credential_id} not found") + name = credential_record.credential_name + if not name: raise ValueError("Invalid load balancing config name") @@ -346,11 +377,6 @@ class ModelLoadBalancingService: load_balancing_config = current_load_balancing_configs_dict[config_id] - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") - if credentials: if not isinstance(credentials, dict): raise ValueError("Invalid load balancing config credentials") @@ -377,39 +403,48 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == "__inherit__": + if name in {"__inherit__", "__delete__"}: raise ValueError("Invalid load balancing config name") - # check duplicate name - for current_load_balancing_config in current_load_balancing_configs: - if current_load_balancing_config.name == name: - raise ValueError(f"Load balancing config name {name} already exists") + if credential_id: + credential_source = "provider" if config_from == "predefined-model" else "custom_model" + assert credential_record is not None + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=credential_record.credential_name, + encrypted_config=credential_record.encrypted_config, + credential_id=credential_id, + credential_source_type=credential_source, + ) + else: + if not credentials: + raise ValueError("Invalid load balancing config credentials") - if not credentials: - raise ValueError("Invalid load balancing config credentials") + if not isinstance(credentials, dict): + raise ValueError("Invalid load balancing config credentials") - if not isinstance(credentials, dict): - raise ValueError("Invalid load balancing config credentials") + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type_enum, + model=model, + credentials=credentials, + validate=False, + ) - # validate custom provider config - credentials = self._custom_credentials_validate( - tenant_id=tenant_id, - provider_configuration=provider_configuration, - model_type=model_type_enum, - model=model, - credentials=credentials, - validate=False, - ) - - # create load balancing config - load_balancing_model_config = LoadBalancingModelConfig( - tenant_id=tenant_id, - provider_name=provider_configuration.provider.provider, - model_type=model_type_enum.to_origin_model_type(), - model_name=model, - name=name, - encrypted_config=json.dumps(credentials), - ) + # create load balancing config + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider_configuration.provider.provider, + model_type=model_type_enum.to_origin_model_type(), + model_name=model, + name=name, + encrypted_config=json.dumps(credentials), + ) db.session.add(load_balancing_model_config) db.session.commit() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 54197bf94..67c3f0d6b 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -16,6 +16,7 @@ from services.entities.model_provider_entities import ( SimpleProviderEntityResponse, SystemConfigurationResponse, ) +from services.errors.app_model_config import ProviderNotFoundError logger = logging.getLogger(__name__) @@ -28,6 +29,29 @@ class ModelProviderService: def __init__(self) -> None: self.provider_manager = ProviderManager() + def _get_provider_configuration(self, tenant_id: str, provider: str): + """ + Get provider configuration or raise exception if not found. + + Args: + tenant_id: Workspace identifier + provider: Provider name + + Returns: + Provider configuration instance + + Raises: + ProviderNotFoundError: If provider doesn't exist + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + raise ProviderNotFoundError(f"Provider {provider} does not exist.") + + return provider_configuration + def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: """ get provider list. @@ -46,6 +70,9 @@ class ModelProviderService: if model_type_entity not in provider_configuration.provider.supported_model_types: continue + provider_config = provider_configuration.custom_configuration.provider + model_config = provider_configuration.custom_configuration.models + provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, @@ -63,7 +90,11 @@ class ModelProviderService: custom_configuration=CustomConfigurationResponse( status=CustomConfigurationStatus.ACTIVE if provider_configuration.is_custom_configuration_available() - else CustomConfigurationStatus.NO_CONFIGURE + else CustomConfigurationStatus.NO_CONFIGURE, + current_credential_id=getattr(provider_config, "current_credential_id", None), + current_credential_name=getattr(provider_config, "current_credential_name", None), + available_credentials=getattr(provider_config, "available_credentials", []), + custom_models=model_config, ), system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, @@ -82,8 +113,8 @@ class ModelProviderService: For the model provider page, only supports passing in a single provider to query the list of supported models. - :param tenant_id: - :param provider: + :param tenant_id: workspace id + :param provider: provider name :return: """ # Get all provider configurations of the current workspace @@ -95,98 +126,111 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: Optional[str] = None + ) -> Optional[dict]: """ get provider credentials. - """ - provider_configurations = self.provider_manager.get_configurations(tenant_id) - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - return provider_configuration.get_custom_credentials(obfuscated=True) - - def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - validate provider credentials. - - :param tenant_id: - :param provider: - :param credentials: - """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - provider_configuration.custom_credentials_validate(credentials) - - def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: - """ - save custom provider config. :param tenant_id: workspace id :param provider: provider name - :param credentials: provider credentials + :param credential_id: credential id, if not provided, return current used credentials :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom provider credentials. - provider_configuration.add_or_update_custom_credentials(credentials) - - def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: """ - remove custom provider config. + validate provider credentials before saving. :param tenant_id: workspace id :param provider: provider name + :param credentials: provider credentials dict + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_provider_credentials(credentials) + + def create_provider_credential( + self, tenant_id: str, provider: str, credentials: dict, credential_name: str + ) -> None: + """ + Create and save new provider credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_provider_credential(credentials, credential_name) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom provider credentials. - provider_configuration.delete_custom_credentials() - - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: + def update_provider_credential( + self, + tenant_id: str, + provider: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: """ - get model credentials. + update a saved provider credential (by credential_id). + + :param tenant_id: workspace id + :param provider: provider name + :param credentials: provider credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_provider_credential( + credential_id=credential_id, + credentials=credentials, + credential_name=credential_name, + ) + + def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + remove a saved provider credential (by credential_id). + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_provider_credential(credential_id=credential_id) + + def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: + """ + :param tenant_id: workspace id + :param provider: provider name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_active_provider_credential(credential_id=credential_id) + + def get_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None + ) -> Optional[dict]: + """ + Retrieve model-specific credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name + :param credential_id: Optional credential ID, uses current if not provided :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Get model custom credentials from ProviderModel if exists - return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, obfuscated=True + provider_configuration = self._get_provider_configuration(tenant_id, provider) + return provider_configuration.get_custom_model_credential( # type: ignore + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def model_credentials_validate( + def validate_model_credentials( self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict ) -> None: """ @@ -196,49 +240,122 @@ class ModelProviderService: :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Validate model credentials - provider_configuration.custom_model_credentials_validate( + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.validate_custom_model_credentials( model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + def create_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str ) -> None: """ - save model credentials. + create and save model credentials. :param tenant_id: workspace id :param provider: provider name :param model_type: model type :param model: model name - :param credentials: model credentials + :param credentials: model credentials dict + :param credential_name: credential name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Add or update custom model credentials - provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), model=model, credentials=credentials + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.create_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_name=credential_name, ) - def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: + def update_model_credential( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict, + credential_id: str, + credential_name: str, + ) -> None: + """ + update model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credentials: model credentials dict + :param credential_id: credential id + :param credential_name: credential name + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.update_custom_model_credential( + model_type=ModelType.value_of(model_type), + model=model, + credentials=credentials, + credential_id=credential_id, + credential_name=credential_name, + ) + + def remove_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + remove model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def switch_active_custom_model_credential( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + switch model credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.switch_custom_model_credential( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def add_model_credential_to_model_list( + self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str + ) -> None: + """ + add model credentials to model list. + + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credential_id: credential id + :return: + """ + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.add_model_credential_to_model( + model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id + ) + + def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: """ remove model credentials. @@ -248,16 +365,8 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Remove custom model credentials - provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) + provider_configuration = self._get_provider_configuration(tenant_id, provider) + provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -331,13 +440,7 @@ class ModelProviderService: :param model: model name :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") + provider_configuration = self._get_provider_configuration(tenant_id, provider) # fetch credentials credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) @@ -424,17 +527,11 @@ class ModelProviderService: :param preferred_provider_type: preferred provider type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configuration = self._get_provider_configuration(tenant_id, provider) # Convert preferred_provider_type to ProviderType preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) @@ -448,15 +545,7 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: @@ -469,13 +558,5 @@ class ModelProviderService: :param model_type: model type :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration - provider_configuration = provider_configurations.get(provider) - if not provider_configuration: - raise ValueError(f"Provider {provider} does not exist.") - - # Enable model + provider_configuration = self._get_provider_configuration(tenant_id, provider) provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8b7d44c1e..ee1ba2b25 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -235,10 +235,17 @@ class TestModelProviderService: mock_provider_entity.provider_credential_schema = None mock_provider_entity.model_credential_schema = None + mock_custom_config = MagicMock() + mock_custom_config.provider.current_credential_id = "credential-123" + mock_custom_config.provider.current_credential_name = "test-credential" + mock_custom_config.provider.available_credentials = [] + mock_custom_config.models = [] + mock_provider_config = MagicMock() mock_provider_config.provider = mock_provider_entity mock_provider_config.preferred_provider_type = ProviderType.CUSTOM mock_provider_config.is_custom_configuration_available.return_value = True + mock_provider_config.custom_configuration = mock_custom_config mock_provider_config.system_configuration.enabled = True mock_provider_config.system_configuration.current_quota_type = "free" mock_provider_config.system_configuration.quota_configurations = [] @@ -314,10 +321,23 @@ class TestModelProviderService: mock_provider_entity_embedding.provider_credential_schema = None mock_provider_entity_embedding.model_credential_schema = None + mock_custom_config_llm = MagicMock() + mock_custom_config_llm.provider.current_credential_id = "credential-123" + mock_custom_config_llm.provider.current_credential_name = "test-credential" + mock_custom_config_llm.provider.available_credentials = [] + mock_custom_config_llm.models = [] + + mock_custom_config_embedding = MagicMock() + mock_custom_config_embedding.provider.current_credential_id = "credential-456" + mock_custom_config_embedding.provider.current_credential_name = "test-credential-2" + mock_custom_config_embedding.provider.available_credentials = [] + mock_custom_config_embedding.models = [] + mock_provider_config_llm = MagicMock() mock_provider_config_llm.provider = mock_provider_entity_llm mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM mock_provider_config_llm.is_custom_configuration_available.return_value = True + mock_provider_config_llm.custom_configuration = mock_custom_config_llm mock_provider_config_llm.system_configuration.enabled = True mock_provider_config_llm.system_configuration.current_quota_type = "free" mock_provider_config_llm.system_configuration.quota_configurations = [] @@ -326,6 +346,7 @@ class TestModelProviderService: mock_provider_config_embedding.provider = mock_provider_entity_embedding mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM mock_provider_config_embedding.is_custom_configuration_available.return_value = True + mock_provider_config_embedding.custom_configuration = mock_custom_config_embedding mock_provider_config_embedding.system_configuration.enabled = True mock_provider_config_embedding.system_configuration.current_quota_type = "free" mock_provider_config_embedding.system_configuration.quota_configurations = [] @@ -497,20 +518,29 @@ class TestModelProviderService: } mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + # Expected result structure + expected_credentials = { + "credentials": { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + } + # Act: Execute the method under test service = ModelProviderService() - result = service.get_provider_credentials(tenant.id, "openai") + with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method: + result = service.get_provider_credential(tenant.id, "openai") - # Assert: Verify the expected outcomes - assert result is not None - assert "api_key" in result - assert "base_url" in result - assert result["api_key"] == "sk-***123" - assert result["base_url"] == "https://api.openai.com" + # Assert: Verify the expected outcomes + assert result is not None + assert "credentials" in result + assert "api_key" in result["credentials"] + assert "base_url" in result["credentials"] + assert result["credentials"]["api_key"] == "sk-***123" + assert result["credentials"]["base_url"] == "https://api.openai.com" - # Verify mock interactions - mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True) + # Verify the method was called with correct parameters + mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( self, db_session_with_containers, mock_external_service_dependencies @@ -548,11 +578,11 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() # This should not raise an exception - service.provider_credentials_validate(tenant.id, "openai", test_credentials) + service.validate_provider_credentials(tenant.id, "openai", test_credentials) # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials) + mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( self, db_session_with_containers, mock_external_service_dependencies @@ -581,7 +611,7 @@ class TestModelProviderService: # Act & Assert: Execute the method under test and verify exception service = ModelProviderService() with pytest.raises(ValueError, match="Provider nonexistent does not exist."): - service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials) + service.validate_provider_credentials(tenant.id, "nonexistent", test_credentials) # Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) @@ -817,22 +847,29 @@ class TestModelProviderService: } mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + # Expected result structure + expected_credentials = { + "credentials": { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + } + # Act: Execute the method under test service = ModelProviderService() - result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4") + with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method: + result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None) - # Assert: Verify the expected outcomes - assert result is not None - assert "api_key" in result - assert "base_url" in result - assert result["api_key"] == "sk-***123" - assert result["base_url"] == "https://api.openai.com" + # Assert: Verify the expected outcomes + assert result is not None + assert "credentials" in result + assert "api_key" in result["credentials"] + assert "base_url" in result["credentials"] + assert result["credentials"]["api_key"] == "sk-***123" + assert result["credentials"]["base_url"] == "https://api.openai.com" - # Verify mock interactions - mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.get_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4", obfuscated=True - ) + # Verify the method was called with correct parameters + mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -868,11 +905,11 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() # This should not raise an exception - service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials) + service.validate_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with( + mock_provider_configuration.validate_custom_model_credentials.assert_called_once_with( model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) @@ -909,12 +946,12 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() - service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) + service.create_model_credential(tenant.id, "openai", "llm", "gpt-4", test_credentials, "testname") # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials + mock_provider_configuration.create_custom_model_credential.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -942,17 +979,17 @@ class TestModelProviderService: # Create mock provider configuration with remove method mock_provider_configuration = MagicMock() - mock_provider_configuration.delete_custom_model_credentials.return_value = None + mock_provider_configuration.delete_custom_model_credential.return_value = None mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} # Act: Execute the method under test service = ModelProviderService() - service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4") + service.remove_model_credential(tenant.id, "openai", "llm", "gpt-4", "5540007c-b988-46e0-b1c7-9b5fb9f330d6") # Assert: Verify mock interactions mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) - mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with( - model_type=ModelType.LLM, model="gpt-4" + mock_provider_configuration.delete_custom_model_credential.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py new file mode 100644 index 000000000..75621ecb6 --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -0,0 +1,308 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, +) +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from models.provider import Provider, ProviderType + + +@pytest.fixture +def mock_provider_entity(): + """Mock provider entity with basic configuration""" + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), + background="background.png", + help=None, + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + provider_credential_schema=None, + model_credential_schema=None, + ) + + return provider_entity + + +@pytest.fixture +def mock_system_configuration(): + """Mock system configuration""" + quota_config = QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1000, + quota_used=0, + is_valid=True, + restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)], + ) + + system_config = SystemConfiguration( + enabled=True, + credentials={"openai_api_key": "test_key"}, + quota_configurations=[quota_config], + current_quota_type=ProviderQuotaType.TRIAL, + ) + + return system_config + + +@pytest.fixture +def mock_custom_configuration(): + """Mock custom configuration""" + custom_config = CustomConfiguration(provider=None, models=[]) + return custom_config + + +@pytest.fixture +def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration): + """Create a test provider configuration instance""" + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + +class TestProviderConfiguration: + """Test cases for ProviderConfiguration class""" + + def test_get_current_credentials_system_provider_success(self, provider_configuration): + """Test successfully getting credentials from system provider""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} + + def test_get_current_credentials_model_disabled(self, provider_configuration): + """Test getting credentials when model is disabled""" + # Arrange + model_setting = ModelSettings( + model="gpt-4", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + has_invalid_load_balancing_configs=False, + ) + provider_configuration.model_settings = [model_setting] + + # Act & Assert + with pytest.raises(ValueError, match="Model gpt-4 is disabled"): + provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + def test_get_current_credentials_custom_provider_with_models(self, provider_configuration): + """Test getting credentials from custom provider with model configurations""" + # Arrange + provider_configuration.using_provider_type = ProviderType.CUSTOM + + mock_model_config = Mock() + mock_model_config.model_type = ModelType.LLM + mock_model_config.model = "gpt-4" + mock_model_config.credentials = {"openai_api_key": "custom_key"} + provider_configuration.custom_configuration.models = [mock_model_config] + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "custom_key"} + + def test_get_system_configuration_status_active(self, provider_configuration): + """Test getting active system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.ACTIVE + + def test_get_system_configuration_status_unsupported(self, provider_configuration): + """Test getting unsupported system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.UNSUPPORTED + + def test_get_system_configuration_status_quota_exceeded(self, provider_configuration): + """Test getting quota exceeded system configuration status""" + # Arrange + provider_configuration.system_configuration.enabled = True + quota_config = provider_configuration.system_configuration.quota_configurations[0] + quota_config.is_valid = False + + # Act + status = provider_configuration.get_system_configuration_status() + + # Assert + assert status == SystemConfigurationStatus.QUOTA_EXCEEDED + + def test_is_custom_configuration_available_with_provider(self, provider_configuration): + """Test custom configuration availability with provider credentials""" + # Arrange + mock_provider = Mock() + mock_provider.available_credentials = ["openai_api_key"] + provider_configuration.custom_configuration.provider = mock_provider + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_with_models(self, provider_configuration): + """Test custom configuration availability with model configurations""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [Mock()] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is True + + def test_is_custom_configuration_available_false(self, provider_configuration): + """Test custom configuration not available""" + # Arrange + provider_configuration.custom_configuration.provider = None + provider_configuration.custom_configuration.models = [] + + # Act + result = provider_configuration.is_custom_configuration_available() + + # Assert + assert result is False + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_found(self, mock_session, provider_configuration): + """Test getting provider record successfully""" + # Arrange + mock_provider = Mock(spec=Provider) + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result == mock_provider + + @patch("core.entities.provider_configuration.Session") + def test_get_provider_record_not_found(self, mock_session, provider_configuration): + """Test getting provider record when not found""" + # Arrange + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act + result = provider_configuration._get_provider_record(mock_session_instance) + + # Assert + assert result is None + + def test_init_with_customizable_model_only( + self, mock_provider_entity, mock_system_configuration, mock_custom_configuration + ): + """Test initialization with customizable model only configuration""" + # Arrange + mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL] + + # Act + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + config = ProviderConfiguration( + tenant_id="test_tenant", + provider=mock_provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=mock_system_configuration, + custom_configuration=mock_custom_configuration, + model_settings=[], + ) + + # Assert + assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods + + def test_get_current_credentials_with_restricted_models(self, provider_configuration): + """Test getting credentials with model restrictions""" + # Arrange + provider_configuration.using_provider_type = ProviderType.SYSTEM + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo") + + # Assert + assert credentials is not None + assert "openai_api_key" in credentials + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_success(self, mock_session, provider_configuration): + """Test getting specific provider credential successfully""" + # Arrange + credential_id = "test_credential_id" + mock_credential = Mock() + mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}' + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential + + # Act + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = {"openai_api_key": "test_key"} + result = provider_configuration._get_specific_provider_credential(credential_id) + + # Assert + assert result == {"openai_api_key": "test_key"} + + @patch("core.entities.provider_configuration.Session") + def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration): + """Test getting specific provider credential when not found""" + # Arrange + credential_id = "nonexistent_credential_id" + + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None + + # Act & Assert + with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: + mock_get.return_value = None + result = provider_configuration._get_specific_provider_credential(credential_id) + assert result is None + + # Act + credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") + + # Assert + assert credentials == {"openai_api_key": "test_key"} diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 90d5a6f15..2dab39402 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,190 +1,185 @@ -# from core.entities.provider_entities import ModelSettings -# from core.model_runtime.entities.model_entities import ModelType -# from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -# from core.provider_manager import ProviderManager -# from models.provider import LoadBalancingModelConfig, ProviderModelSetting +import pytest + +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting -# def test__to_model_settings(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +@pytest.fixture +def mock_provider_entity(mocker): + mock_entity = mocker.Mock() + mock_entity.provider = "openai" + mock_entity.configurate_methods = ["predefined-model"] + mock_entity.supported_model_types = [ModelType.LLM] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mock_entity.model_credential_schema = mocker.Mock() + mock_entity.model_credential_schema.credential_form_schemas = [] -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] - -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) - -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 2 -# assert result[0].load_balancing_configs[0].name == "__inherit__" -# assert result[0].load_balancing_configs[1].name == "first" + return mock_entity -# def test__to_model_settings_only_one_lb(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=True, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ) -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() - -# # Running the method -# result = provider_manager._to_model_settings( -# provider_entity, provider_model_settings, load_balancing_model_configs) - -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 2 + assert result[0].load_balancing_configs[0].name == "__inherit__" + assert result[0].load_balancing_configs[1].name == "first" -# def test__to_model_settings_lb_disabled(mocker): -# # Get all provider entities -# model_provider_factory = ModelProviderFactory("test_tenant") -# provider_entities = model_provider_factory.get_providers() +def test__to_model_settings_only_one_lb(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=True, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ) + ] -# provider_entity = None -# for provider in provider_entities: -# if provider.provider == "openai": -# provider_entity = provider + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) -# # Mocking the inputs -# provider_model_settings = [ -# ProviderModelSetting( -# id="id", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# enabled=True, -# load_balancing_enabled=False, -# ) -# ] -# load_balancing_model_configs = [ -# LoadBalancingModelConfig( -# id="id1", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="__inherit__", -# encrypted_config=None, -# enabled=True, -# ), -# LoadBalancingModelConfig( -# id="id2", -# tenant_id="tenant_id", -# provider_name="openai", -# model_name="gpt-4", -# model_type="text-generation", -# name="first", -# encrypted_config='{"openai_api_key": "fake_key"}', -# enabled=True, -# ), -# ] + provider_manager = ProviderManager() -# mocker.patch( -# "core.helper.model_provider_cache.ProviderCredentialsCache.get", -# return_value={"openai_api_key": "fake_key"} -# ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) -# provider_manager = ProviderManager() + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 -# # Running the method -# result = provider_manager._to_model_settings(provider_entity, -# provider_model_settings, load_balancing_model_configs) -# # Asserting that the result is as expected -# assert len(result) == 1 -# assert isinstance(result[0], ModelSettings) -# assert result[0].model == "gpt-4" -# assert result[0].model_type == ModelType.LLM -# assert result[0].enabled is True -# assert len(result[0].load_balancing_configs) == 0 +def test__to_model_settings_lb_disabled(mocker, mock_provider_entity): + # Mocking the inputs + provider_model_settings = [ + ProviderModelSetting( + id="id", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + enabled=True, + load_balancing_enabled=False, + ) + ] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id="id1", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="__inherit__", + encrypted_config=None, + enabled=True, + ), + LoadBalancingModelConfig( + id="id2", + tenant_id="tenant_id", + provider_name="openai", + model_name="gpt-4", + model_type="text-generation", + name="first", + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True, + ), + ] + + mocker.patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} + ) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == "gpt-4" + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 diff --git a/web/app/components/base/form/components/base/base-field.tsx b/web/app/components/base/form/components/base/base-field.tsx index 00a1f9b2d..4005bab6b 100644 --- a/web/app/components/base/form/components/base/base-field.tsx +++ b/web/app/components/base/form/components/base/base-field.tsx @@ -30,7 +30,7 @@ const BaseField = ({ inputClassName, formSchema, field, - disabled, + disabled: propsDisabled, }: BaseFieldProps) => { const renderI18nObject = useRenderI18nObject() const { @@ -40,7 +40,9 @@ const BaseField = ({ options, labelClassName: formLabelClassName, show_on = [], + disabled: formSchemaDisabled, } = formSchema + const disabled = propsDisabled || formSchemaDisabled const memorizedLabel = useMemo(() => { if (isValidElement(label)) @@ -72,7 +74,7 @@ const BaseField = ({ }) const memorizedOptions = useMemo(() => { return options?.filter((option) => { - if (!option.show_on?.length) + if (!option.show_on || option.show_on.length === 0) return true return option.show_on.every((condition) => { @@ -85,7 +87,7 @@ const BaseField = ({ value: option.value, } }) || [] - }, [options, renderI18nObject]) + }, [options, renderI18nObject, optionValues]) const value = useStore(field.form.store, s => s.values[field.name]) const values = useStore(field.form.store, (s) => { return show_on.reduce((acc, condition) => { @@ -182,9 +184,10 @@ const BaseField = ({ className={cn( 'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 flex-[1] grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary', value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs', + disabled && 'cursor-not-allowed opacity-50', inputClassName, )} - onClick={() => field.handleChange(option.value)} + onClick={() => !disabled && field.handleChange(option.value)} > { formSchema.showRadioUI && ( diff --git a/web/app/components/base/form/hooks/use-get-validators.ts b/web/app/components/base/form/hooks/use-get-validators.ts index 91754bc1b..63b93d2c0 100644 --- a/web/app/components/base/form/hooks/use-get-validators.ts +++ b/web/app/components/base/form/hooks/use-get-validators.ts @@ -1,34 +1,52 @@ -import { useCallback } from 'react' +import { + isValidElement, + useCallback, +} from 'react' +import type { ReactNode } from 'react' import { useTranslation } from 'react-i18next' import type { FormSchema } from '../types' +import { useRenderI18nObject } from '@/hooks/use-i18n' export const useGetValidators = () => { const { t } = useTranslation() + const renderI18nObject = useRenderI18nObject() + const getLabel = useCallback((label: string | Record | ReactNode) => { + if (isValidElement(label)) + return '' + + if (typeof label === 'string') + return label + + if (typeof label === 'object' && label !== null) + return renderI18nObject(label as Record) + }, []) const getValidators = useCallback((formSchema: FormSchema) => { const { name, validators, required, + label, } = formSchema let mergedValidators = validators + const memorizedLabel = getLabel(label) if (required && !validators) { mergedValidators = { onMount: ({ value }: any) => { if (!value) - return t('common.errorMsg.fieldRequired', { field: name }) + return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) }, onChange: ({ value }: any) => { if (!value) - return t('common.errorMsg.fieldRequired', { field: name }) + return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) }, onBlur: ({ value }: any) => { if (!value) - return t('common.errorMsg.fieldRequired', { field: name }) + return t('common.errorMsg.fieldRequired', { field: memorizedLabel }) }, } } return mergedValidators - }, [t]) + }, [t, getLabel]) return { getValidators, diff --git a/web/app/components/base/form/types.ts b/web/app/components/base/form/types.ts index 9b3beeee7..5c8e36126 100644 --- a/web/app/components/base/form/types.ts +++ b/web/app/components/base/form/types.ts @@ -59,6 +59,7 @@ export type FormSchema = { labelClassName?: string validators?: AnyValidators showRadioUI?: boolean + disabled?: boolean } export type FormValues = Record diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 1f5ced612..74f47c9d1 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -86,6 +86,7 @@ export enum ModelStatusEnum { quotaExceeded = 'quota-exceeded', noPermission = 'no-permission', disabled = 'disabled', + credentialRemoved = 'credential-removed', } export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { @@ -153,6 +154,7 @@ export type ModelItem = { model_properties: Record load_balancing_enabled: boolean deprecated?: boolean + has_invalid_load_balancing_configs?: boolean } export enum PreferredProviderTypeEnum { @@ -181,6 +183,29 @@ export type QuotaConfiguration = { is_valid: boolean } +export type Credential = { + credential_id: string + credential_name?: string + from_enterprise?: boolean + not_allowed_to_use?: boolean +} + +export type CustomModel = { + model: string + model_type: ModelTypeEnum +} + +export type CustomModelCredential = CustomModel & { + credentials?: Record + available_model_credentials?: Credential[] + current_credential_id?: string +} + +export type CredentialWithModel = Credential & { + model: string + model_type: ModelTypeEnum +} + export type ModelProvider = { provider: string label: TypeWithI18N @@ -207,12 +232,17 @@ export type ModelProvider = { preferred_provider_type: PreferredProviderTypeEnum custom_configuration: { status: CustomConfigurationStatusEnum + current_credential_id?: string + current_credential_name?: string + available_credentials?: Credential[] + custom_models?: CustomModelCredential[] } system_configuration: { enabled: boolean current_quota_type: CurrentSystemQuotaTypeEnum quota_configurations: QuotaConfiguration[] } + allow_custom_token?: boolean } export type Model = { @@ -272,9 +302,24 @@ export type ModelLoadBalancingConfigEntry = { in_cooldown?: boolean /** cooldown time (in seconds) */ ttl?: number + credential_id?: string } export type ModelLoadBalancingConfig = { enabled: boolean configs: ModelLoadBalancingConfigEntry[] } + +export type ProviderCredential = { + credentials: Record + name: string + credential_id: string +} + +export type ModelCredential = { + credentials: Record + load_balancing: ModelLoadBalancingConfig + available_credentials: Credential[] + current_credential_id?: string + current_credential_name?: string +} diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 48acaeb64..fa5130137 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -7,7 +7,9 @@ import { import useSWR, { useSWRConfig } from 'swr' import { useContext } from 'use-context-selector' import type { + Credential, CustomConfigurationModelFixedFields, + CustomModel, DefaultModel, DefaultModelResponse, Model, @@ -77,16 +79,17 @@ export const useProviderCredentialsAndLoadBalancing = ( configurationMethod: ConfigurationMethodEnum, configured?: boolean, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + credentialId?: string, ) => { - const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR( - (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured) - ? `/workspaces/current/model-providers/${provider}/credentials` + const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR( + (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId) + ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}` : null, fetchModelProviderCredentials, ) - const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR( - (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields) - ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}` + const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR( + (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId) + ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}` : null, fetchModelProviderCredentials, ) @@ -102,6 +105,7 @@ export const useProviderCredentialsAndLoadBalancing = ( : undefined }, [ configurationMethod, + credentialId, currentCustomConfigurationModelFixedFields, customFormSchemasValue?.credentials, predefinedFormSchemasValue?.credentials, @@ -119,6 +123,7 @@ export const useProviderCredentialsAndLoadBalancing = ( : customFormSchemasValue )?.load_balancing, mutate, + isLoading: isPredefinedLoading || isCustomizedLoading, } // as ([Record | undefined, ModelLoadBalancingConfig | undefined]) } @@ -313,40 +318,59 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: } } -export const useModelModalHandler = () => { - const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) +export const useRefreshModel = () => { + const { eventEmitter } = useEventEmitterContextContext() const updateModelProviders = useUpdateModelProviders() const updateModelList = useUpdateModelList() - const { eventEmitter } = useEventEmitterContextContext() + const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => { + updateModelProviders() + + provider.supported_model_types.forEach((type) => { + updateModelList(type) + }) + + if (configurationMethod === ConfigurationMethodEnum.customizableModel + && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { + eventEmitter?.emit({ + type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, + payload: provider.provider, + } as any) + + if (CustomConfigurationModelFixedFields?.__model_type) + updateModelList(CustomConfigurationModelFixedFields.__model_type) + } + }, [eventEmitter, updateModelList, updateModelProviders]) + + return { + handleRefreshModel, + } +} + +export const useModelModalHandler = () => { + const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) + const { handleRefreshModel } = useRefreshModel() return ( provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean, + credential?: Credential, + model?: CustomModel, + onUpdate?: () => void, ) => { setShowModelModal({ payload: { currentProvider: provider, currentConfigurationMethod: configurationMethod, currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, + isModelCredential, + credential, + model, }, onSaveCallback: () => { - updateModelProviders() - - provider.supported_model_types.forEach((type) => { - updateModelList(type) - }) - - if (configurationMethod === ConfigurationMethodEnum.customizableModel - && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { - eventEmitter?.emit({ - type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, - payload: provider.provider, - } as any) - - if (CustomConfigurationModelFixedFields?.__model_type) - updateModelList(CustomConfigurationModelFixedFields.__model_type) - } + handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields) + onUpdate?.() }, }) } diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index 4aa98daf6..35de29185 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -8,8 +8,6 @@ import { import SystemModelSelector from './system-model-selector' import ProviderAddedCard from './provider-added-card' import type { - ConfigurationMethodEnum, - CustomConfigurationModelFixedFields, ModelProvider, } from './declarations' import { @@ -18,7 +16,6 @@ import { } from './declarations' import { useDefaultModel, - useModelModalHandler, } from './hooks' import InstallFromMarketplace from './install-from-marketplace' import { useProviderContext } from '@/context/provider-context' @@ -84,8 +81,6 @@ const ModelProviderPage = ({ searchText }: Props) => { return [filteredConfiguredProviders, filteredNotConfiguredProviders] }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) - const handleOpenModal = useModelModalHandler() - return (
@@ -126,7 +121,6 @@ const ModelProviderPage = ({ searchText }: Props) => { handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} /> ))}
@@ -140,7 +134,6 @@ const ModelProviderPage = ({ searchText }: Props) => { notConfigured key={provider.provider} provider={provider} - onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} /> ))}
diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx new file mode 100644 index 000000000..64e631614 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-credential-in-load-balancing.tsx @@ -0,0 +1,115 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { RiAddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { Authorized } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import cn from '@/utils/classnames' +import type { + Credential, + CustomModelCredential, + ModelCredential, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Tooltip from '@/app/components/base/tooltip' + +type AddCredentialInLoadBalancingProps = { + provider: ModelProvider + model: CustomModelCredential + configurationMethod: ConfigurationMethodEnum + modelCredential: ModelCredential + onSelectCredential: (credential: Credential) => void + onUpdate?: () => void +} +const AddCredentialInLoadBalancing = ({ + provider, + model, + configurationMethod, + modelCredential, + onSelectCredential, + onUpdate, +}: AddCredentialInLoadBalancingProps) => { + const { t } = useTranslation() + const { + available_credentials, + } = modelCredential + const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel + const notAllowCustomCredential = provider.allow_custom_token === false + + const ButtonComponent = useMemo(() => { + const Item = ( +
+ + { + customModel + ? t('common.modelProvider.auth.addCredential') + : t('common.modelProvider.auth.addApiKey') + } +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, t, customModel]) + + const renderTrigger = useCallback((open?: boolean) => { + const Item = ( +
+ + { + customModel + ? t('common.modelProvider.auth.addCredential') + : t('common.modelProvider.auth.addApiKey') + } +
+ ) + + return Item + }, [t, customModel]) + + if (!available_credentials?.length) + return ButtonComponent + + return ( + + ) +} + +export default memo(AddCredentialInLoadBalancing) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx new file mode 100644 index 000000000..0ec6fa45a --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/add-custom-model.tsx @@ -0,0 +1,111 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAddCircleFill, +} from '@remixicon/react' +import { + Button, +} from '@/app/components/base/button' +import type { + CustomConfigurationModelFixedFields, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Authorized from './authorized' +import { + useAuth, + useCustomModels, +} from './hooks' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' + +type AddCustomModelProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, +} +const AddCustomModel = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, +}: AddCustomModelProps) => { + const { t } = useTranslation() + const customModels = useCustomModels(provider) + const noModels = !customModels.length + const { + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true) + const notAllowCustomCredential = provider.allow_custom_token === false + const handleClick = useCallback(() => { + if (notAllowCustomCredential) + return + + handleOpenModal() + }, [handleOpenModal, notAllowCustomCredential]) + const ButtonComponent = useMemo(() => { + const Item = ( + + ) + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [handleClick, notAllowCustomCredential, t]) + + const renderTrigger = useCallback((open?: boolean) => { + const Item = ( + + ) + return Item + }, [t]) + + if (noModels) + return ButtonComponent + + return ( + ({ + model, + credentials: model.available_model_credentials ?? [], + }))} + renderTrigger={renderTrigger} + isModelCredential + enableAddModelCredential + bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')} + /> + ) +} + +export default memo(AddCustomModel) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx new file mode 100644 index 000000000..4f4c30bc9 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/authorized-item.tsx @@ -0,0 +1,101 @@ +import { + memo, + useCallback, +} from 'react' +import { RiAddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import CredentialItem from './credential-item' +import type { + Credential, + CustomModel, + CustomModelCredential, +} from '../../declarations' +import Button from '@/app/components/base/button' +import Tooltip from '@/app/components/base/tooltip' + +type AuthorizedItemProps = { + model?: CustomModelCredential + title?: string + disabled?: boolean + onDelete?: (credential?: Credential, model?: CustomModel) => void + onEdit?: (credential?: Credential, model?: CustomModel) => void + showItemSelectedIcon?: boolean + selectedCredentialId?: string + credentials: Credential[] + onItemClick?: (credential: Credential, model?: CustomModel) => void + enableAddModelCredential?: boolean + notAllowCustomCredential?: boolean +} +export const AuthorizedItem = ({ + model, + title, + credentials, + disabled, + onDelete, + onEdit, + showItemSelectedIcon, + selectedCredentialId, + onItemClick, + enableAddModelCredential, + notAllowCustomCredential, +}: AuthorizedItemProps) => { + const { t } = useTranslation() + const handleEdit = useCallback((credential?: Credential) => { + onEdit?.(credential, model) + }, [onEdit, model]) + const handleDelete = useCallback((credential?: Credential) => { + onDelete?.(credential, model) + }, [onDelete, model]) + const handleItemClick = useCallback((credential: Credential) => { + onItemClick?.(credential, model) + }, [onItemClick, model]) + + return ( +
+
+
+
+ {title ?? model?.model} +
+ { + enableAddModelCredential && !notAllowCustomCredential && ( + + + + ) + } +
+ { + credentials.map(credential => ( + + )) + } +
+ ) +} + +export default memo(AuthorizedItem) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx new file mode 100644 index 000000000..6596e64e0 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/credential-item.tsx @@ -0,0 +1,137 @@ +import { + memo, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCheckLine, + RiDeleteBinLine, + RiEqualizer2Line, +} from '@remixicon/react' +import Indicator from '@/app/components/header/indicator' +import ActionButton from '@/app/components/base/action-button' +import Tooltip from '@/app/components/base/tooltip' +import cn from '@/utils/classnames' +import type { Credential } from '../../declarations' +import Badge from '@/app/components/base/badge' + +type CredentialItemProps = { + credential: Credential + disabled?: boolean + onDelete?: (credential: Credential) => void + onEdit?: (credential?: Credential) => void + onItemClick?: (credential: Credential) => void + disableRename?: boolean + disableEdit?: boolean + disableDelete?: boolean + showSelectedIcon?: boolean + selectedCredentialId?: string +} +const CredentialItem = ({ + credential, + disabled, + onDelete, + onEdit, + onItemClick, + disableRename, + disableEdit, + disableDelete, + showSelectedIcon, + selectedCredentialId, +}: CredentialItemProps) => { + const { t } = useTranslation() + const showAction = useMemo(() => { + return !(disableRename && disableEdit && disableDelete) + }, [disableRename, disableEdit, disableDelete]) + + const Item = ( +
{ + if (disabled || credential.not_allowed_to_use) + return + onItemClick?.(credential) + }} + > +
+ { + showSelectedIcon && ( +
+ { + selectedCredentialId === credential.credential_id && ( + + ) + } +
+ ) + } + +
+ {credential.credential_name} +
+
+ { + credential.from_enterprise && ( + + Enterprise + + ) + } + { + showAction && ( +
+ { + !disableEdit && !credential.not_allowed_to_use && !credential.from_enterprise && ( + + { + e.stopPropagation() + onEdit?.(credential) + }} + > + + + + ) + } + { + !disableDelete && !credential.from_enterprise && ( + + { + e.stopPropagation() + onDelete?.(credential) + }} + > + + + + ) + } +
+ ) + } +
+ ) + + if (credential.not_allowed_to_use) { + return ( + + {Item} + + ) + } + return Item +} + +export default memo(CredentialItem) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx new file mode 100644 index 000000000..3e7c04a0f --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx @@ -0,0 +1,222 @@ +import { + memo, + useCallback, + useMemo, + useState, +} from 'react' +import { + RiAddLine, + RiEqualizer2Line, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { + PortalToFollowElemOptions, +} from '@/app/components/base/portal-to-follow-elem' +import Button from '@/app/components/base/button' +import cn from '@/utils/classnames' +import Confirm from '@/app/components/base/confirm' +import type { + ConfigurationMethodEnum, + Credential, + CustomConfigurationModelFixedFields, + CustomModel, + ModelProvider, +} from '../../declarations' +import { useAuth } from '../hooks' +import AuthorizedItem from './authorized-item' + +type AuthorizedProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean + items: { + title?: string + model?: CustomModel + credentials: Credential[] + }[] + selectedCredential?: Credential + disabled?: boolean + renderTrigger?: (open?: boolean) => React.ReactNode + isOpen?: boolean + onOpenChange?: (open: boolean) => void + offset?: PortalToFollowElemOptions['offset'] + placement?: PortalToFollowElemOptions['placement'] + triggerPopupSameWidth?: boolean + popupClassName?: string + showItemSelectedIcon?: boolean + onUpdate?: () => void + onItemClick?: (credential: Credential, model?: CustomModel) => void + enableAddModelCredential?: boolean + bottomAddModelCredentialText?: string +} +const Authorized = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + items, + isModelCredential, + selectedCredential, + disabled, + renderTrigger, + isOpen, + onOpenChange, + offset = 8, + placement = 'bottom-end', + triggerPopupSameWidth = false, + popupClassName, + showItemSelectedIcon, + onUpdate, + onItemClick, + enableAddModelCredential, + bottomAddModelCredentialText, +}: AuthorizedProps) => { + const { t } = useTranslation() + const [isLocalOpen, setIsLocalOpen] = useState(false) + const mergedIsOpen = isOpen ?? isLocalOpen + const setMergedIsOpen = useCallback((open: boolean) => { + if (onOpenChange) + onOpenChange(open) + + setIsLocalOpen(open) + }, [onOpenChange]) + const { + openConfirmDelete, + closeConfirmDelete, + doingAction, + handleActiveCredential, + handleConfirmDelete, + deleteCredentialId, + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate) + + const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => { + handleOpenModal(credential, model) + setMergedIsOpen(false) + }, [handleOpenModal, setMergedIsOpen]) + + const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => { + if (onItemClick) + onItemClick(credential, model) + else + handleActiveCredential(credential, model) + + setMergedIsOpen(false) + }, [handleActiveCredential, onItemClick, setMergedIsOpen]) + const notAllowCustomCredential = provider.allow_custom_token === false + + const Trigger = useMemo(() => { + const Item = ( + + ) + return Item + }, [t]) + + return ( + <> + + { + setMergedIsOpen(!mergedIsOpen) + }} + asChild + > + { + renderTrigger + ? renderTrigger(mergedIsOpen) + : Trigger + } + + +
+
+ { + items.map((item, index) => ( + + )) + } +
+
+ { + isModelCredential && !notAllowCustomCredential && ( +
handleEdit( + undefined, + currentCustomConfigurationModelFixedFields + ? { + model: currentCustomConfigurationModelFixedFields.__model_name, + model_type: currentCustomConfigurationModelFixedFields.__model_type, + } + : undefined, + )} + className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' + > + + {bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')} +
+ ) + } + { + !isModelCredential && !notAllowCustomCredential && ( +
+ +
+ ) + } +
+
+
+ { + deleteCredentialId && ( + + ) + } + + ) +} + +export default memo(Authorized) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx new file mode 100644 index 000000000..02d9eb274 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/config-model.tsx @@ -0,0 +1,76 @@ +import { memo } from 'react' +import { + RiEqualizer2Line, + RiScales3Line, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import cn from '@/utils/classnames' + +type ConfigModelProps = { + onClick?: () => void + loadBalancingEnabled?: boolean + loadBalancingInvalid?: boolean + credentialRemoved?: boolean +} +const ConfigModel = ({ + onClick, + loadBalancingEnabled, + loadBalancingInvalid, + credentialRemoved, +}: ConfigModelProps) => { + const { t } = useTranslation() + + if (loadBalancingInvalid) { + return ( +
+ + {t('common.modelProvider.auth.authorizationError')} + +
+ ) + } + + return ( + + ) +} + +export default memo(ConfigModel) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx new file mode 100644 index 000000000..ba9049a83 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/config-provider.tsx @@ -0,0 +1,96 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiEqualizer2Line, +} from '@remixicon/react' +import { + Button, +} from '@/app/components/base/button' +import type { + CustomConfigurationModelFixedFields, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import Authorized from './authorized' +import { useAuth, useCredentialStatus } from './hooks' +import Tooltip from '@/app/components/base/tooltip' +import cn from '@/utils/classnames' + +type ConfigProviderProps = { + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, +} +const ConfigProvider = ({ + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, +}: ConfigProviderProps) => { + const { t } = useTranslation() + const { + handleOpenModal, + } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields) + const { + hasCredential, + authorized, + current_credential_id, + current_credential_name, + available_credentials, + } = useCredentialStatus(provider) + const notAllowCustomCredential = provider.allow_custom_token === false + const handleClick = useCallback(() => { + if (!hasCredential && !notAllowCustomCredential) + handleOpenModal() + }, [handleOpenModal, hasCredential, notAllowCustomCredential]) + const ButtonComponent = useMemo(() => { + const Item = ( + + ) + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [handleClick, authorized, notAllowCustomCredential, t]) + + if (!hasCredential) + return ButtonComponent + + return ( + + ) +} + +export default memo(ConfigProvider) diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts new file mode 100644 index 000000000..fd0bee512 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts @@ -0,0 +1,6 @@ +export * from './use-model-form-schemas' +export * from './use-credential-status' +export * from './use-custom-models' +export * from './use-auth' +export * from './use-auth-service' +export * from './use-credential-data' diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts new file mode 100644 index 000000000..317a1fe1a --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth-service.ts @@ -0,0 +1,57 @@ +import { useCallback } from 'react' +import { + useActiveModelCredential, + useActiveProviderCredential, + useAddModelCredential, + useAddProviderCredential, + useDeleteModelCredential, + useDeleteProviderCredential, + useEditModelCredential, + useEditProviderCredential, + useGetModelCredential, + useGetProviderCredential, +} from '@/service/use-models' +import type { + CustomModel, +} from '@/app/components/header/account-setting/model-provider-page/declarations' + +export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => { + const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId) + const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom) + return isModelCredential ? modelData : providerData +} + +export const useAuthService = (provider: string) => { + const { mutateAsync: addProviderCredential } = useAddProviderCredential(provider) + const { mutateAsync: editProviderCredential } = useEditProviderCredential(provider) + const { mutateAsync: deleteProviderCredential } = useDeleteProviderCredential(provider) + const { mutateAsync: activeProviderCredential } = useActiveProviderCredential(provider) + + const { mutateAsync: addModelCredential } = useAddModelCredential(provider) + const { mutateAsync: activeModelCredential } = useActiveModelCredential(provider) + const { mutateAsync: deleteModelCredential } = useDeleteModelCredential(provider) + const { mutateAsync: editModelCredential } = useEditModelCredential(provider) + + const getAddCredentialService = useCallback((isModel: boolean) => { + return isModel ? addModelCredential : addProviderCredential + }, [addModelCredential, addProviderCredential]) + + const getEditCredentialService = useCallback((isModel: boolean) => { + return isModel ? editModelCredential : editProviderCredential + }, [editModelCredential, editProviderCredential]) + + const getDeleteCredentialService = useCallback((isModel: boolean) => { + return isModel ? deleteModelCredential : deleteProviderCredential + }, [deleteModelCredential, deleteProviderCredential]) + + const getActiveCredentialService = useCallback((isModel: boolean) => { + return isModel ? activeModelCredential : activeProviderCredential + }, [activeModelCredential, activeProviderCredential]) + + return { + getAddCredentialService, + getEditCredentialService, + getDeleteCredentialService, + getActiveCredentialService, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts new file mode 100644 index 000000000..d4a0417a4 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts @@ -0,0 +1,158 @@ +import { + useCallback, + useRef, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { useToastContext } from '@/app/components/base/toast' +import { useAuthService } from './use-auth-service' +import type { + ConfigurationMethodEnum, + Credential, + CustomConfigurationModelFixedFields, + CustomModel, + ModelProvider, +} from '../../declarations' +import { + useModelModalHandler, + useRefreshModel, +} from '@/app/components/header/account-setting/model-provider-page/hooks' + +export const useAuth = ( + provider: ModelProvider, + configurationMethod: ConfigurationMethodEnum, + currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, + isModelCredential?: boolean, + onUpdate?: () => void, +) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const { + getDeleteCredentialService, + getActiveCredentialService, + getEditCredentialService, + getAddCredentialService, + } = useAuthService(provider.provider) + const handleOpenModelModal = useModelModalHandler() + const { handleRefreshModel } = useRefreshModel() + const pendingOperationCredentialId = useRef(null) + const pendingOperationModel = useRef(null) + const [deleteCredentialId, setDeleteCredentialId] = useState(null) + const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => { + if (credential) + pendingOperationCredentialId.current = credential.credential_id + if (model) + pendingOperationModel.current = model + + setDeleteCredentialId(pendingOperationCredentialId.current) + }, []) + const closeConfirmDelete = useCallback(() => { + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + }, []) + const [doingAction, setDoingAction] = useState(false) + const doingActionRef = useRef(doingAction) + const handleSetDoingAction = useCallback((doing: boolean) => { + doingActionRef.current = doing + setDoingAction(doing) + }, []) + const handleActiveCredential = useCallback(async (credential: Credential, model?: CustomModel) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + await getActiveCredentialService(!!model)({ + credential_id: credential.credential_id, + model: model?.model, + model_type: model?.model_type, + }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + handleRefreshModel(provider, configurationMethod, undefined) + } + finally { + handleSetDoingAction(false) + } + }, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction]) + const handleConfirmDelete = useCallback(async () => { + if (doingActionRef.current) + return + if (!pendingOperationCredentialId.current) { + setDeleteCredentialId(null) + return + } + try { + handleSetDoingAction(true) + await getDeleteCredentialService(!!isModelCredential)({ + credential_id: pendingOperationCredentialId.current, + model: pendingOperationModel.current?.model, + model_type: pendingOperationModel.current?.model_type, + }) + notify({ + type: 'success', + message: t('common.api.actionSuccess'), + }) + onUpdate?.() + handleRefreshModel(provider, configurationMethod, undefined) + setDeleteCredentialId(null) + pendingOperationCredentialId.current = null + pendingOperationModel.current = null + } + finally { + handleSetDoingAction(false) + } + }, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential]) + const handleAddCredential = useCallback((model?: CustomModel) => { + if (model) + pendingOperationModel.current = model + }, []) + const handleSaveCredential = useCallback(async (payload: Record) => { + if (doingActionRef.current) + return + try { + handleSetDoingAction(true) + + let res: { result?: string } = {} + if (payload.credential_id) + res = await getEditCredentialService(!!isModelCredential)(payload as any) + else + res = await getAddCredentialService(!!isModelCredential)(payload as any) + + if (res.result === 'success') { + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + onUpdate?.() + } + } + finally { + handleSetDoingAction(false) + } + }, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) + const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => { + handleOpenModelModal( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + isModelCredential, + credential, + model, + onUpdate, + ) + }, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate]) + + return { + pendingOperationCredentialId, + pendingOperationModel, + openConfirmDelete, + closeConfirmDelete, + doingAction, + handleActiveCredential, + handleConfirmDelete, + handleAddCredential, + deleteCredentialId, + handleSaveCredential, + handleOpenModal, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts new file mode 100644 index 000000000..2fbc8b103 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-data.ts @@ -0,0 +1,24 @@ +import { useMemo } from 'react' +import { useGetCredential } from './use-auth-service' +import type { + Credential, + CustomModelCredential, + ModelProvider, +} from '@/app/components/header/account-setting/model-provider-page/declarations' + +export const useCredentialData = (provider: ModelProvider, providerFormSchemaPredefined: boolean, isModelCredential?: boolean, credential?: Credential, model?: CustomModelCredential) => { + const configFrom = useMemo(() => { + if (providerFormSchemaPredefined) + return 'predefined-model' + return 'custom-model' + }, [providerFormSchemaPredefined]) + const { + isLoading, + data: credentialData = {}, + } = useGetCredential(provider.provider, isModelCredential, credential?.credential_id, model, configFrom) + + return { + isLoading, + credentialData, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts new file mode 100644 index 000000000..3fa3877b3 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-credential-status.ts @@ -0,0 +1,26 @@ +import { useMemo } from 'react' +import type { + ModelProvider, +} from '../../declarations' + +export const useCredentialStatus = (provider: ModelProvider) => { + const { + current_credential_id, + current_credential_name, + available_credentials, + } = provider.custom_configuration + const hasCredential = !!available_credentials?.length + const authorized = current_credential_id && current_credential_name + const authRemoved = hasCredential && !current_credential_id && !current_credential_name + const currentCredential = available_credentials?.find(credential => credential.credential_id === current_credential_id) + + return useMemo(() => ({ + hasCredential, + authorized, + authRemoved, + current_credential_id, + current_credential_name, + available_credentials, + notAllowedToUse: currentCredential?.not_allowed_to_use, + }), [hasCredential, authorized, authRemoved, current_credential_id, current_credential_name, available_credentials]) +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts new file mode 100644 index 000000000..f3b50f3f4 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts @@ -0,0 +1,9 @@ +import type { + ModelProvider, +} from '../../declarations' + +export const useCustomModels = (provider: ModelProvider) => { + const { custom_models } = provider.custom_configuration + + return custom_models || [] +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts new file mode 100644 index 000000000..eafbedfdd --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts @@ -0,0 +1,83 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import type { + Credential, + CustomModelCredential, + ModelLoadBalancingConfig, + ModelProvider, +} from '../../declarations' +import { + genModelNameFormSchema, + genModelTypeFormSchema, +} from '../../utils' +import { FormTypeEnum } from '@/app/components/base/form/types' + +export const useModelFormSchemas = ( + provider: ModelProvider, + providerFormSchemaPredefined: boolean, + credentials?: Record, + credential?: Credential, + model?: CustomModelCredential, + draftConfig?: ModelLoadBalancingConfig, +) => { + const { t } = useTranslation() + const { + provider_credential_schema, + supported_model_types, + model_credential_schema, + } = provider + const formSchemas = useMemo(() => { + const modelTypeSchema = genModelTypeFormSchema(supported_model_types) + const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) + if (!!model) { + modelTypeSchema.disabled = true + modelNameSchema.disabled = true + } + return providerFormSchemaPredefined + ? provider_credential_schema.credential_form_schemas + : [ + modelTypeSchema, + modelNameSchema, + ...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas), + ] + }, [ + providerFormSchemaPredefined, + provider_credential_schema?.credential_form_schemas, + supported_model_types, + model_credential_schema?.credential_form_schemas, + model_credential_schema?.model, + draftConfig?.enabled, + model, + ]) + + const formSchemasWithAuthorizationName = useMemo(() => { + const authorizationNameSchema = { + type: FormTypeEnum.textInput, + variable: '__authorization_name__', + label: t('plugin.auth.authorizationName'), + required: true, + } + + return [ + authorizationNameSchema, + ...formSchemas, + ] + }, [formSchemas, t]) + + const formValues = useMemo(() => { + let result = {} + if (credential) { + result = { ...result, __authorization_name__: credential?.credential_name } + if (credentials) + result = { ...result, ...credentials } + } + if (model) + result = { ...result, __model_name: model?.model, __model_type: model?.model_type } + return result + }, [credentials, credential, model]) + + return { + formSchemas: formSchemasWithAuthorizationName, + formValues, + } +} diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx new file mode 100644 index 000000000..05effcea7 --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/index.tsx @@ -0,0 +1,6 @@ +export { default as Authorized } from './authorized' +export { default as SwitchCredentialInLoadBalancing } from './switch-credential-in-load-balancing' +export { default as AddCredentialInLoadBalancing } from './add-credential-in-load-balancing' +export { default as AddCustomModel } from './add-custom-model' +export { default as ConfigProvider } from './config-provider' +export { default as ConfigModel } from './config-model' diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx new file mode 100644 index 000000000..8f81107bb --- /dev/null +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx @@ -0,0 +1,122 @@ +import type { Dispatch, SetStateAction } from 'react' +import { + memo, + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { RiArrowDownSLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import Indicator from '@/app/components/header/indicator' +import Authorized from './authorized' +import type { + Credential, + CustomModel, + ModelProvider, +} from '../declarations' +import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' +import Badge from '@/app/components/base/badge' + +type SwitchCredentialInLoadBalancingProps = { + provider: ModelProvider + model: CustomModel + credentials?: Credential[] + customModelCredential?: Credential + setCustomModelCredential: Dispatch> +} +const SwitchCredentialInLoadBalancing = ({ + provider, + model, + customModelCredential, + setCustomModelCredential, + credentials, +}: SwitchCredentialInLoadBalancingProps) => { + const { t } = useTranslation() + + const handleItemClick = useCallback((credential: Credential) => { + setCustomModelCredential(credential) + }, [setCustomModelCredential]) + + const renderTrigger = useCallback(() => { + const selectedCredentialId = customModelCredential?.credential_id + const authRemoved = !selectedCredentialId && !!credentials?.length + let color = 'green' + if (authRemoved && !customModelCredential?.not_allowed_to_use) + color = 'red' + if (customModelCredential?.not_allowed_to_use) + color = 'gray' + + const Item = ( + + ) + if (customModelCredential?.not_allowed_to_use) { + return ( + + {Item} + + ) + } + return Item + }, [customModelCredential, t, credentials]) + + return ( + + ) +} + +export default memo(SwitchCredentialInLoadBalancing) diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index f6fb1dc6f..02c7c404a 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -13,12 +13,14 @@ type ModelIconProps = { provider?: Model | ModelProvider modelName?: string className?: string + iconClassName?: string isDeprecated?: boolean } const ModelIcon: FC = ({ provider, className, modelName, + iconClassName, isDeprecated = false, }) => { const language = useLanguage() @@ -34,7 +36,7 @@ const ModelIcon: FC = ({ if (provider?.icon_small) { return (
- model-icon + model-icon
) } @@ -44,7 +46,7 @@ const ModelIcon: FC = ({ 'flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle', className, )}> -
+
diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx index bc98081df..e9050e483 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/index.tsx @@ -2,43 +2,22 @@ import type { FC } from 'react' import { memo, useCallback, - useEffect, useMemo, - useState, + useRef, } from 'react' +import { RiCloseLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' -import { - RiErrorWarningFill, -} from '@remixicon/react' import type { - CredentialFormSchema, - CredentialFormSchemaRadio, - CredentialFormSchemaSelect, CustomConfigurationModelFixedFields, - FormValue, - ModelLoadBalancingConfig, - ModelLoadBalancingConfigEntry, ModelProvider, } from '../declarations' import { ConfigurationMethodEnum, - CustomConfigurationStatusEnum, FormTypeEnum, } from '../declarations' -import { - genModelNameFormSchema, - genModelTypeFormSchema, - removeCredentials, - saveCredentials, -} from '../utils' import { useLanguage, - useProviderCredentialsAndLoadBalancing, } from '../hooks' -import { useValidate } from '../../key-validator/hooks' -import { ValidatedStatus } from '../../key-validator/declarations' -import ModelLoadBalancingConfigs from '../provider-added-card/model-load-balancing-configs' -import Form from './Form' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' @@ -46,9 +25,26 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' import Confirm from '@/app/components/base/confirm' import { useAppContext } from '@/context/app-context' +import AuthForm from '@/app/components/base/form/form-scenarios/auth' +import type { + FormRefObject, + FormSchema, +} from '@/app/components/base/form/types' +import { useModelFormSchemas } from '../model-auth/hooks' +import type { + Credential, + CustomModel, +} from '../declarations' +import Loading from '@/app/components/base/loading' +import { + useAuth, + useCredentialData, +} from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import Badge from '@/app/components/base/badge' +import { useRenderI18nObject } from '@/hooks/use-i18n' type ModelModalProps = { provider: ModelProvider @@ -56,6 +52,9 @@ type ModelModalProps = { currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields onCancel: () => void onSave: () => void + model?: CustomModel + credential?: Credential + isModelCredential?: boolean } const ModelModal: FC = ({ @@ -64,244 +63,173 @@ const ModelModal: FC = ({ currentCustomConfigurationModelFixedFields, onCancel, onSave, + model, + credential, + isModelCredential, }) => { + const renderI18nObject = useRenderI18nObject() const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel + const { + isLoading, + credentialData, + } = useCredentialData(provider, providerFormSchemaPredefined, isModelCredential, credential, model) + const { + handleSaveCredential, + handleConfirmDelete, + deleteCredentialId, + closeConfirmDelete, + openConfirmDelete, + doingAction, + } = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave) const { credentials: formSchemasValue, - loadBalancing: originalConfig, - mutate, - } = useProviderCredentialsAndLoadBalancing( - provider.provider, - configurateMethod, - providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, - currentCustomConfigurationModelFixedFields, - ) + } = credentialData as any + const { isCurrentWorkspaceManager } = useAppContext() const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager const { t } = useTranslation() - const { notify } = useToastContext() const language = useLanguage() - const [loading, setLoading] = useState(false) - const [showConfirm, setShowConfirm] = useState(false) + const { + formSchemas, + formValues, + } = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model) + const formRef = useRef(null) - const [draftConfig, setDraftConfig] = useState() - const originalConfigMap = useMemo(() => { - if (!originalConfig) - return {} - return originalConfig?.configs.reduce((prev, config) => { - if (config.id) - prev[config.id] = config - return prev - }, {} as Record) - }, [originalConfig]) - useEffect(() => { - if (originalConfig && !draftConfig) - setDraftConfig(originalConfig) - }, [draftConfig, originalConfig]) + const handleSave = useCallback(async () => { + const { + isCheckValidated, + values, + } = formRef.current?.getFormValues({ + needCheckValidatedValues: true, + needTransformWhenSecretFieldIsPristine: true, + }) || { isCheckValidated: false, values: {} } + if (!isCheckValidated) + return - const formSchemas = useMemo(() => { - return providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : [ - genModelTypeFormSchema(provider.supported_model_types), - genModelNameFormSchema(provider.model_credential_schema?.model), - ...(draftConfig?.enabled ? [] : provider.model_credential_schema.credential_form_schemas), - ] - }, [ - providerFormSchemaPredefined, - provider.provider_credential_schema?.credential_form_schemas, - provider.supported_model_types, - provider.model_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.model, - draftConfig?.enabled, - ]) - const [ - requiredFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] = useMemo(() => { - const requiredFormSchemas: CredentialFormSchema[] = [] - const defaultFormSchemaValue: Record = {} - const showOnVariableMap: Record = {} + const { + __authorization_name__, + __model_name, + __model_type, + ...rest + } = values + if (__model_name && __model_type) { + handleSaveCredential({ + credential_id: credential?.credential_id, + credentials: rest, + name: __authorization_name__, + model: __model_name, + model_type: __model_type, + }) + } + else { + handleSaveCredential({ + credential_id: credential?.credential_id, + credentials: rest, + name: __authorization_name__, + }) + } + }, [handleSaveCredential, credential?.credential_id, model]) - formSchemas.forEach((formSchema) => { - if (formSchema.required) - requiredFormSchemas.push(formSchema) - - if (formSchema.default) - defaultFormSchemaValue[formSchema.variable] = formSchema.default - - if (formSchema.show_on.length) { - formSchema.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - - if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { - (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { - if (option.show_on.length) { - option.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - }) - } - }) - - return [ - requiredFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] - }, [formSchemas]) - const initialFormSchemasValue: Record = useMemo(() => { - return { - ...defaultFormSchemaValue, - ...formSchemasValue, - } as unknown as Record - }, [formSchemasValue, defaultFormSchemaValue]) - const [value, setValue] = useState(initialFormSchemasValue) - useEffect(() => { - setValue(initialFormSchemasValue) - }, [initialFormSchemasValue]) - const [_, validating, validatedStatusState] = useValidate(value) - const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { - if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) - return true - - if (!requiredFormSchema.show_on.length) - return true - - return false - }) - - const handleValueChange = (v: FormValue) => { - setValue(v) - } - - const extendedSecretFormSchemas = useMemo( - () => - (providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : [ - genModelTypeFormSchema(provider.supported_model_types), - genModelNameFormSchema(provider.model_credential_schema?.model), - ...provider.model_credential_schema.credential_form_schemas, - ]).filter(({ type }) => type === FormTypeEnum.secretInput), - [ - provider.model_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.model, - provider.provider_credential_schema?.credential_form_schemas, - provider.supported_model_types, - providerFormSchemaPredefined, - ], - ) - - const encodeSecretValues = useCallback((v: FormValue) => { - const result = { ...v } - extendedSecretFormSchemas.forEach(({ variable }) => { - if (result[variable] === formSchemasValue?.[variable] && result[variable] !== undefined) - result[variable] = '[__HIDDEN__]' - }) - return result - }, [extendedSecretFormSchemas, formSchemasValue]) - - const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { - const result = { ...entry } - extendedSecretFormSchemas.forEach(({ variable }) => { - if (entry.id && result.credentials[variable] === originalConfigMap[entry.id]?.credentials?.[variable]) - result.credentials[variable] = '[__HIDDEN__]' - }) - return result - }, [extendedSecretFormSchemas, originalConfigMap]) - - const handleSave = async () => { - try { - setLoading(true) - const res = await saveCredentials( - providerFormSchemaPredefined, - provider.provider, - encodeSecretValues(value), - { - ...draftConfig, - enabled: Boolean(draftConfig?.enabled), - configs: draftConfig?.configs.map(encodeConfigEntrySecretValues) || [], - }, + const modalTitle = useMemo(() => { + if (!providerFormSchemaPredefined && !model) { + return ( +
+ +
+
{t('common.modelProvider.auth.apiKeyModal.addModel')}
+
{renderI18nObject(provider.label)}
+
+
) - if (res.result === 'success') { - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() - onSave() - onCancel() - } } - finally { - setLoading(false) - } - } + let label = t('common.modelProvider.auth.apiKeyModal.title') - const handleRemove = async () => { - try { - setLoading(true) + if (model) + label = t('common.modelProvider.auth.addModelCredential') - const res = await removeCredentials( - providerFormSchemaPredefined, - provider.provider, - value, + return ( +
+ {label} +
+ ) + }, [providerFormSchemaPredefined, t, model, renderI18nObject]) + + const modalDesc = useMemo(() => { + if (providerFormSchemaPredefined) { + return ( +
+ {t('common.modelProvider.auth.apiKeyModal.desc')} +
) - if (res.result === 'success') { - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() - onSave() - onCancel() - } } - finally { - setLoading(false) - } - } - const renderTitlePrefix = () => { - const prefix = isEditMode ? t('common.operation.setup') : t('common.operation.add') - return `${prefix} ${provider.label[language] || provider.label.en_US}` - } + return null + }, [providerFormSchemaPredefined, t]) + + const modalModel = useMemo(() => { + if (model) { + return ( +
+ +
{model.model}
+ {model.model_type} +
+ ) + } + + return null + }, [model, provider]) return (
-
-
-
-
{renderTitlePrefix()}
+
+
+ +
+
+
+ {modalTitle} + {modalDesc} + {modalModel}
-
-
- + { + isLoading && ( +
+ +
+ ) + } + { + !isLoading && ( + { + return { + ...formSchema, + name: formSchema.variable, + showRadioUI: formSchema.type === FormTypeEnum.radio, + } + }) as FormSchema[]} + defaultValues={formValues} + inputClassName='justify-start' + ref={formRef} + /> + ) + }
@@ -327,7 +255,7 @@ const ModelModal: FC = ({ variant='warning' size='large' className='mr-2' - onClick={() => setShowConfirm(true)} + onClick={() => openConfirmDelete(credential, model)} > {t('common.operation.remove')} @@ -344,12 +272,7 @@ const ModelModal: FC = ({ size='large' variant='primary' onClick={handleSave} - disabled={ - loading - || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined) - || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) - } - + disabled={isLoading || doingAction} > {t('common.operation.save')} @@ -357,38 +280,28 @@ const ModelModal: FC = ({
- { - (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) - ? ( -
- - {validatedStatusState.message} -
- ) - : ( -
- - {t('common.modelProvider.encrypted.front')} - - PKCS1_OAEP - - {t('common.modelProvider.encrypted.back')} -
- ) - } +
+ + {t('common.modelProvider.encrypted.front')} + + PKCS1_OAEP + + {t('common.modelProvider.encrypted.back')} +
{ - showConfirm && ( + deleteCredentialId && ( setShowConfirm(false)} - onConfirm={handleRemove} + isDisabled={doingAction} + onCancel={closeConfirmDelete} + onConfirm={handleConfirmDelete} /> ) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx deleted file mode 100644 index d6285a784..000000000 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal.tsx +++ /dev/null @@ -1,348 +0,0 @@ -import type { FC } from 'react' -import { - memo, - useCallback, - useEffect, - useMemo, - useState, -} from 'react' -import { useTranslation } from 'react-i18next' -import { - RiErrorWarningFill, -} from '@remixicon/react' -import type { - CredentialFormSchema, - CredentialFormSchemaRadio, - CredentialFormSchemaSelect, - CredentialFormSchemaTextInput, - CustomConfigurationModelFixedFields, - FormValue, - ModelLoadBalancingConfigEntry, - ModelProvider, -} from '../declarations' -import { - ConfigurationMethodEnum, - FormTypeEnum, -} from '../declarations' - -import { - useLanguage, -} from '../hooks' -import { useValidate } from '../../key-validator/hooks' -import { ValidatedStatus } from '../../key-validator/declarations' -import { validateLoadBalancingCredentials } from '../utils' -import Form from './Form' -import Button from '@/app/components/base/button' -import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, -} from '@/app/components/base/portal-to-follow-elem' -import { useToastContext } from '@/app/components/base/toast' -import Confirm from '@/app/components/base/confirm' - -type ModelModalProps = { - provider: ModelProvider - configurationMethod: ConfigurationMethodEnum - currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields - entry?: ModelLoadBalancingConfigEntry - onCancel: () => void - onSave: (entry: ModelLoadBalancingConfigEntry) => void - onRemove: () => void -} - -const ModelLoadBalancingEntryModal: FC = ({ - provider, - configurationMethod, - currentCustomConfigurationModelFixedFields, - entry, - onCancel, - onSave, - onRemove, -}) => { - const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel - // const { credentials: formSchemasValue } = useProviderCredentialsAndLoadBalancing( - // provider.provider, - // configurationMethod, - // providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, - // currentCustomConfigurationModelFixedFields, - // ) - const isEditMode = !!entry - const { t } = useTranslation() - const { notify } = useToastContext() - const language = useLanguage() - const [loading, setLoading] = useState(false) - const [showConfirm, setShowConfirm] = useState(false) - const formSchemas = useMemo(() => { - return [ - { - type: FormTypeEnum.textInput, - label: { - en_US: 'Config Name', - zh_Hans: '配置名称', - }, - variable: 'name', - required: true, - show_on: [], - placeholder: { - en_US: 'Enter your Config Name here', - zh_Hans: '输入配置名称', - }, - } as CredentialFormSchemaTextInput, - ...( - providerFormSchemaPredefined - ? provider.provider_credential_schema.credential_form_schemas - : provider.model_credential_schema.credential_form_schemas - ), - ] - }, [ - providerFormSchemaPredefined, - provider.provider_credential_schema?.credential_form_schemas, - provider.model_credential_schema?.credential_form_schemas, - ]) - - const [ - requiredFormSchemas, - secretFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] = useMemo(() => { - const requiredFormSchemas: CredentialFormSchema[] = [] - const secretFormSchemas: CredentialFormSchema[] = [] - const defaultFormSchemaValue: Record = {} - const showOnVariableMap: Record = {} - - formSchemas.forEach((formSchema) => { - if (formSchema.required) - requiredFormSchemas.push(formSchema) - - if (formSchema.type === FormTypeEnum.secretInput) - secretFormSchemas.push(formSchema) - - if (formSchema.default) - defaultFormSchemaValue[formSchema.variable] = formSchema.default - - if (formSchema.show_on.length) { - formSchema.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - - if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { - (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { - if (option.show_on.length) { - option.show_on.forEach((showOnItem) => { - if (!showOnVariableMap[showOnItem.variable]) - showOnVariableMap[showOnItem.variable] = [] - - if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) - showOnVariableMap[showOnItem.variable].push(formSchema.variable) - }) - } - }) - } - }) - - return [ - requiredFormSchemas, - secretFormSchemas, - defaultFormSchemaValue, - showOnVariableMap, - ] - }, [formSchemas]) - const [initialValue, setInitialValue] = useState() - useEffect(() => { - if (entry && !initialValue) { - setInitialValue({ - ...defaultFormSchemaValue, - ...entry.credentials, - id: entry.id, - name: entry.name, - } as Record) - } - }, [entry, defaultFormSchemaValue, initialValue]) - const formSchemasValue = useMemo(() => ({ - ...currentCustomConfigurationModelFixedFields, - ...initialValue, - }), [currentCustomConfigurationModelFixedFields, initialValue]) - const initialFormSchemasValue: Record = useMemo(() => { - return { - ...defaultFormSchemaValue, - ...formSchemasValue, - } as Record - }, [formSchemasValue, defaultFormSchemaValue]) - const [value, setValue] = useState(initialFormSchemasValue) - useEffect(() => { - setValue(initialFormSchemasValue) - }, [initialFormSchemasValue]) - const [_, validating, validatedStatusState] = useValidate(value) - const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { - if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) - return true - - if (!requiredFormSchema.show_on.length) - return true - - return false - }) - const getSecretValues = useCallback((v: FormValue) => { - return secretFormSchemas.reduce((prev, next) => { - if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) - prev[next.variable] = '[__HIDDEN__]' - - return prev - }, {} as Record) - }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) - - // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { - const handleValueChange = (v: FormValue) => { - setValue(v) - } - const handleSave = async () => { - try { - setLoading(true) - - const res = await validateLoadBalancingCredentials( - providerFormSchemaPredefined, - provider.provider, - { - ...value, - ...getSecretValues(value), - }, - entry?.id, - ) - if (res.status === ValidatedStatus.Success) { - // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - const { __model_type, __model_name, name, ...credentials } = value - onSave({ - ...(entry || {}), - name: name as string, - credentials: credentials as Record, - }) - // onCancel() - } - else { - notify({ type: 'error', message: res.message || '' }) - } - } - finally { - setLoading(false) - } - } - - const handleRemove = () => { - onRemove?.() - } - - return ( - - -
-
-
-
-
{t(isEditMode ? 'common.modelProvider.editConfig' : 'common.modelProvider.addConfig')}
-
- -
- { - (provider.help && (provider.help.title || provider.help.url)) - ? ( - !provider.help.url && e.preventDefault()} - > - {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} - - - ) - :
- } -
- { - isEditMode && ( - - ) - } - - -
-
-
-
- { - (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) - ? ( -
- - {validatedStatusState.message} -
- ) - : ( -
- - {t('common.modelProvider.encrypted.front')} - - PKCS1_OAEP - - {t('common.modelProvider.encrypted.back')} -
- ) - } -
-
- { - showConfirm && ( - setShowConfirm(false)} - onConfirm={handleRemove} - /> - ) - } -
- - - ) -} - -export default memo(ModelLoadBalancingEntryModal) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index 822df5f72..d57288db3 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -1,7 +1,8 @@ -import type { FC } from 'react' +import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { RiEqualizer2Line } from '@remixicon/react' -import type { ModelProvider } from '../declarations' +import type { + ModelProvider, +} from '../declarations' import { ConfigurationMethodEnum, CustomConfigurationStatusEnum, @@ -15,19 +16,19 @@ import PrioritySelector from './priority-selector' import PriorityUseTip from './priority-use-tip' import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index' import Indicator from '@/app/components/header/indicator' -import Button from '@/app/components/base/button' import { changeModelProviderPriority } from '@/service/common' import { useToastContext } from '@/app/components/base/toast' import { useEventEmitterContextContext } from '@/context/event-emitter' +import cn from '@/utils/classnames' +import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' +import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' type CredentialPanelProps = { provider: ModelProvider - onSetup: () => void } -const CredentialPanel: FC = ({ +const CredentialPanel = ({ provider, - onSetup, -}) => { +}: CredentialPanelProps) => { const { t } = useTranslation() const { notify } = useToastContext() const { eventEmitter } = useEventEmitterContextContext() @@ -38,6 +39,13 @@ const CredentialPanel: FC = ({ const priorityUseType = provider.preferred_provider_type const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active const configurateMethods = provider.configurate_methods + const { + hasCredential, + authorized, + authRemoved, + current_credential_name, + notAllowedToUse, + } = useCredentialStatus(provider) const handleChangePriority = async (key: PreferredProviderTypeEnum) => { const res = await changeModelProviderPriority({ @@ -61,25 +69,50 @@ const CredentialPanel: FC = ({ } as any) } } + const credentialLabel = useMemo(() => { + if (!hasCredential) + return t('common.modelProvider.auth.unAuthorized') + if (authorized) + return current_credential_name + if (authRemoved) + return t('common.modelProvider.auth.authRemoved') + + return '' + }, [authorized, authRemoved, current_credential_name, hasCredential]) + + const color = useMemo(() => { + if (authRemoved) + return 'red' + if (notAllowedToUse) + return 'gray' + return 'green' + }, [authRemoved, notAllowedToUse]) return ( <> { provider.provider_credential_schema && ( -
-
- API-KEY - +
+
+
+ {credentialLabel} +
+
- + { systemConfig.enabled && isCustomConfigured && ( void } const ProviderAddedCard: FC = ({ notConfigured, provider, - onOpenModal, }) => { const { t } = useTranslation() const { eventEmitter } = useEventEmitterContextContext() @@ -114,7 +111,6 @@ const ProviderAddedCard: FC = ({ { showCredential && ( onOpenModal(ConfigurationMethodEnum.predefinedModel)} provider={provider} /> ) @@ -159,9 +155,9 @@ const ProviderAddedCard: FC = ({ )} { configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( - onOpenModal(ConfigurationMethodEnum.customizableModel)} - className='flex' + ) } @@ -174,7 +170,6 @@ const ProviderAddedCard: FC = ({ provider={provider} models={modelList} onCollapse={() => setCollapsed(true)} - onConfig={currentCustomConfigurationModelFixedFields => onOpenModal(ConfigurationMethodEnum.customizableModel, currentCustomConfigurationModelFixedFields)} onChange={(provider: string) => getModelList(provider)} /> ) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx index 8908d9a03..bcd483244 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx @@ -1,31 +1,29 @@ import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useDebounceFn } from 'ahooks' -import type { CustomConfigurationModelFixedFields, ModelItem, ModelProvider } from '../declarations' -import { ConfigurationMethodEnum, ModelStatusEnum } from '../declarations' -import ModelBadge from '../model-badge' +import type { ModelItem, ModelProvider } from '../declarations' +import { ModelStatusEnum } from '../declarations' import ModelIcon from '../model-icon' import ModelName from '../model-name' import classNames from '@/utils/classnames' -import Button from '@/app/components/base/button' import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' import { disableModel, enableModel } from '@/service/common' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' +import { ConfigModel } from '../model-auth' +import Badge from '@/app/components/base/badge' export type ModelListItemProps = { model: ModelItem provider: ModelProvider isConfigurable: boolean - onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void onModifyLoadBalancing?: (model: ModelItem) => void } -const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoadBalancing }: ModelListItemProps) => { +const ModelListItem = ({ model, provider, isConfigurable, onModifyLoadBalancing }: ModelListItemProps) => { const { t } = useTranslation() const { plan } = useProviderContext() const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) @@ -46,7 +44,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoad return (
- {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && ( - - - {t('common.modelProvider.loadBalancingHeadline')} - - )}
+ {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && !model.has_invalid_load_balancing_configs && ( + + + + )} { - model.fetch_from === ConfigurationMethodEnum.customizableModel - ? (isCurrentWorkspaceManager && ( - - )) - : (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) - ? ( - - ) - : null + (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) && ( + onModifyLoadBalancing?.(model)} + loadBalancingEnabled={model.load_balancing_enabled} + loadBalancingInvalid={model.has_invalid_load_balancing_configs} + credentialRemoved={model.status === ModelStatusEnum.credentialRemoved} + /> + ) } { model.deprecated diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx index 699be6edd..8d902043f 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list.tsx @@ -5,7 +5,7 @@ import { RiArrowRightSLine, } from '@remixicon/react' import type { - CustomConfigurationModelFixedFields, + Credential, ModelItem, ModelProvider, } from '../declarations' @@ -13,34 +13,33 @@ import { ConfigurationMethodEnum, } from '../declarations' // import Tab from './tab' -import AddModelButton from './add-model-button' import ModelListItem from './model-list-item' import { useModalContextSelector } from '@/context/modal-context' import { useAppContext } from '@/context/app-context' +import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' type ModelListProps = { provider: ModelProvider models: ModelItem[] onCollapse: () => void - onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void onChange?: (provider: string) => void } const ModelList: FC = ({ provider, models, onCollapse, - onConfig, onChange, }) => { const { t } = useTranslation() const configurativeMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) const { isCurrentWorkspaceManager } = useAppContext() const isConfigurable = configurativeMethods.includes(ConfigurationMethodEnum.customizableModel) - const setShowModelLoadBalancingModal = useModalContextSelector(state => state.setShowModelLoadBalancingModal) - const onModifyLoadBalancing = useCallback((model: ModelItem) => { + const onModifyLoadBalancing = useCallback((model: ModelItem, credential?: Credential) => { setShowModelLoadBalancingModal({ provider, + credential, + configurateMethod: model.fetch_from, model: model!, open: !!model, onClose: () => setShowModelLoadBalancingModal(null), @@ -65,17 +64,14 @@ const ModelList: FC = ({ - {/* { - isConfigurable && canSystemConfig && ( - - {}} /> - - ) - } */} { isConfigurable && isCurrentWorkspaceManager && (
- onConfig()} /> +
) } @@ -83,12 +79,11 @@ const ModelList: FC = ({ { models.map(model => ( diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx index 1a3039659..f92c188aa 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx @@ -1,24 +1,35 @@ import type { Dispatch, SetStateAction } from 'react' -import { useCallback } from 'react' +import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { RiDeleteBinLine, + RiEqualizer2Line, } from '@remixicon/react' -import type { ConfigurationMethodEnum, CustomConfigurationModelFixedFields, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' +import type { + Credential, + CustomConfigurationModelFixedFields, + CustomModelCredential, + ModelCredential, + ModelLoadBalancingConfig, + ModelLoadBalancingConfigEntry, + ModelProvider, +} from '../declarations' +import { ConfigurationMethodEnum } from '../declarations' import Indicator from '../../../indicator' import CooldownTimer from './cooldown-timer' import classNames from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' import Switch from '@/app/components/base/switch' import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' -import { Edit02, Plus02 } from '@/app/components/base/icons/src/vender/line/general' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import { useModalContextSelector } from '@/context/modal-context' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import s from '@/app/components/custom/style.module.css' import GridMask from '@/app/components/base/grid-mask' import { useProviderContextSelector } from '@/context/provider-context' import { IS_CE_EDITION } from '@/config' +import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { useModelModalHandler } from '@/app/components/header/account-setting/model-provider-page/hooks' +import Badge from '@/app/components/base/badge/index' export type ModelLoadBalancingConfigsProps = { draftConfig?: ModelLoadBalancingConfig @@ -28,19 +39,27 @@ export type ModelLoadBalancingConfigsProps = { currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields withSwitch?: boolean className?: string + modelCredential: ModelCredential + onUpdate?: () => void + model: CustomModelCredential } const ModelLoadBalancingConfigs = ({ draftConfig, setDraftConfig, provider, + model, configurationMethod, currentCustomConfigurationModelFixedFields, withSwitch = false, className, + modelCredential, + onUpdate, }: ModelLoadBalancingConfigsProps) => { const { t } = useTranslation() + const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) + const handleOpenModal = useModelModalHandler() const updateConfigEntry = useCallback( ( @@ -65,6 +84,21 @@ const ModelLoadBalancingConfigs = ({ [setDraftConfig], ) + const addConfigEntry = useCallback((credential: Credential) => { + setDraftConfig((prev: any) => { + if (!prev) + return prev + return { + ...prev, + configs: [...prev.configs, { + credential_id: credential.credential_id, + enabled: true, + name: credential.credential_name, + }], + } + }) + }, [setDraftConfig]) + const toggleModalBalancing = useCallback((enabled: boolean) => { if ((modelLoadBalancingEnabled || !enabled) && draftConfig) { setDraftConfig({ @@ -81,54 +115,6 @@ const ModelLoadBalancingConfigs = ({ })) }, [updateConfigEntry]) - const setShowModelLoadBalancingEntryModal = useModalContextSelector(state => state.setShowModelLoadBalancingEntryModal) - - const toggleEntryModal = useCallback((index?: number, entry?: ModelLoadBalancingConfigEntry) => { - setShowModelLoadBalancingEntryModal({ - payload: { - currentProvider: provider, - currentConfigurationMethod: configurationMethod, - currentCustomConfigurationModelFixedFields, - entry, - index, - }, - onSaveCallback: ({ entry: result }) => { - if (entry) { - // edit - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: prev?.configs.map((config, i) => i === index ? result! : config) || [], - })) - } - else { - // add - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: (prev?.configs || []).concat([{ ...result!, enabled: true }]), - })) - } - }, - onRemoveCallback: ({ index }) => { - if (index !== undefined && (draftConfig?.configs?.length ?? 0) > index) { - setDraftConfig(prev => ({ - ...prev, - enabled: !!prev?.enabled, - configs: prev?.configs.filter((_, i) => i !== index) || [], - })) - } - }, - }) - }, [ - configurationMethod, - currentCustomConfigurationModelFixedFields, - draftConfig?.configs?.length, - provider, - setDraftConfig, - setShowModelLoadBalancingEntryModal, - ]) - const clearCountdown = useCallback((index: number) => { updateConfigEntry(index, ({ ttl: _, ...entry }) => { return { @@ -138,6 +124,12 @@ const ModelLoadBalancingConfigs = ({ }) }, [updateConfigEntry]) + const validDraftConfigList = useMemo(() => { + if (!draftConfig) + return [] + return draftConfig.configs + }, [draftConfig]) + if (!draftConfig) return null @@ -181,8 +173,9 @@ const ModelLoadBalancingConfigs = ({
{draftConfig.enabled && (
- {draftConfig.configs.map((config, index) => { + {validDraftConfigList.map((config, index) => { const isProviderManaged = config.name === '__inherit__' + const credential = modelCredential.available_credentials.find(c => c.credential_id === config.credential_id) return (
@@ -200,54 +193,81 @@ const ModelLoadBalancingConfigs = ({
{isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name}
- {isProviderManaged && ( - {t('common.modelProvider.providerManaged')} + {isProviderManaged && providerFormSchemaPredefined && ( + {t('common.modelProvider.providerManaged')} )} + { + credential?.from_enterprise && ( + Enterprise + ) + }
{!isProviderManaged && ( <>
- toggleEntryModal(index, config)} - > - - + { + config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && ( + { + handleOpenModal( + provider, + configurationMethod, + currentCustomConfigurationModelFixedFields, + configurationMethod === ConfigurationMethodEnum.customizableModel, + (config.credential_id && config.name) ? { + credential_id: config.credential_id, + credential_name: config.name, + } : undefined, + model, + ) + }} + > + + + ) + } updateConfigEntry(index, () => undefined)} > -
)} - toggleConfigEntryEnabled(index, value)} - /> + { + (config.credential_id || config.name === '__inherit__') && ( + <> + + toggleConfigEntryEnabled(index, value)} + disabled={credential?.not_allowed_to_use} + /> + + ) + }
) })} - -
toggleEntryModal()} - > -
- {t('common.modelProvider.addConfig')} -
-
+
)} { - draftConfig.enabled && draftConfig.configs.length < 2 && ( -
+ draftConfig.enabled && validDraftConfigList.length < 2 && ( +
{t('common.modelProvider.loadBalancingLeastKeyWarning')}
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx index 9fb07401f..1d6db30c4 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx @@ -1,40 +1,69 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' -import type { ModelItem, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' -import { FormTypeEnum } from '../declarations' +import type { + Credential, + ModelItem, + ModelLoadBalancingConfig, + ModelLoadBalancingConfigEntry, + ModelProvider, +} from '../declarations' +import { + ConfigurationMethodEnum, + FormTypeEnum, +} from '../declarations' import ModelIcon from '../model-icon' import ModelName from '../model-name' -import { savePredefinedLoadBalancingConfig } from '../utils' import ModelLoadBalancingConfigs from './model-load-balancing-configs' import classNames from '@/utils/classnames' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' -import { fetchModelLoadBalancingConfig } from '@/service/common' import Loading from '@/app/components/base/loading' import { useToastContext } from '@/app/components/base/toast' +import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' +import { + useGetModelCredential, + useUpdateModelLoadBalancingConfig, +} from '@/service/use-models' export type ModelLoadBalancingModalProps = { provider: ModelProvider + configurateMethod: ConfigurationMethodEnum model: ModelItem + credential?: Credential open?: boolean onClose?: () => void onSave?: (provider: string) => void } // model balancing config modal -const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSave }: ModelLoadBalancingModalProps) => { +const ModelLoadBalancingModal = ({ + provider, + configurateMethod, + model, + credential, + open = false, + onClose, + onSave, +}: ModelLoadBalancingModalProps) => { const { t } = useTranslation() const { notify } = useToastContext() const [loading, setLoading] = useState(false) - - const { data, mutate } = useSWR( - `/workspaces/current/model-providers/${provider.provider}/models/credentials?model=${model.model}&model_type=${model.model_type}`, - fetchModelLoadBalancingConfig, - ) - - const originalConfig = data?.load_balancing + const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel + const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model' + const { + isLoading, + data, + refetch, + } = useGetModelCredential(true, provider.provider, credential?.credential_id, model.model, model.model_type, configFrom) + const modelCredential = data + const { + load_balancing, + current_credential_id, + available_credentials, + current_credential_name, + } = modelCredential ?? {} + const originalConfig = load_balancing const [draftConfig, setDraftConfig] = useState() const originalConfigMap = useMemo(() => { if (!originalConfig) @@ -60,10 +89,17 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav }, [draftConfig]) const extendedSecretFormSchemas = useMemo( - () => provider.provider_credential_schema.credential_form_schemas.filter( - ({ type }) => type === FormTypeEnum.secretInput, - ), - [provider.provider_credential_schema.credential_form_schemas], + () => { + if (providerFormSchemaPredefined) { + return provider?.provider_credential_schema?.credential_form_schemas?.filter( + ({ type }) => type === FormTypeEnum.secretInput, + ) ?? [] + } + return provider?.model_credential_schema?.credential_form_schemas?.filter( + ({ type }) => type === FormTypeEnum.secretInput, + ) ?? [] + }, + [provider?.model_credential_schema?.credential_form_schemas, provider?.provider_credential_schema?.credential_form_schemas, providerFormSchemaPredefined], ) const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { @@ -75,25 +111,34 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav return result }, [extendedSecretFormSchemas, originalConfigMap]) + const { mutateAsync: updateModelLoadBalancingConfig } = useUpdateModelLoadBalancingConfig(provider.provider) + const initialCustomModelCredential = useMemo(() => { + if (!current_credential_id) + return undefined + return { + credential_id: current_credential_id, + credential_name: current_credential_name, + } + }, [current_credential_id, current_credential_name]) + const [customModelCredential, setCustomModelCredential] = useState(initialCustomModelCredential) const handleSave = async () => { try { setLoading(true) - const res = await savePredefinedLoadBalancingConfig( - provider.provider, - ({ - ...(data?.credentials ?? {}), - __model_type: model.model_type, - __model_name: model.model, - }), + const res = await updateModelLoadBalancingConfig( { - ...draftConfig, - enabled: Boolean(draftConfig?.enabled), - configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), + credential_id: customModelCredential?.credential_id || current_credential_id, + config_from: configFrom, + model: model.model, + model_type: model.model_type, + load_balancing: { + ...draftConfig, + configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), + enabled: Boolean(draftConfig?.enabled), + }, }, ) if (res.result === 'success') { notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) - mutate() onSave?.(provider.provider) onClose?.() } @@ -110,7 +155,11 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav className='w-[640px] max-w-none px-8 pt-8' title={
-
{t('common.modelProvider.configLoadBalancing')}
+
{ + draftConfig?.enabled + ? t('common.modelProvider.auth.configLoadBalancing') + : t('common.modelProvider.auth.configModel') + }
{Boolean(model) && (
-
{t('common.modelProvider.providerManaged')}
-
{t('common.modelProvider.providerManagedDescription')}
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManaged') + : t('common.modelProvider.auth.specifyModelCredential') + }
+
{ + providerFormSchemaPredefined + ? t('common.modelProvider.auth.providerManagedTip') + : t('common.modelProvider.auth.specifyModelCredentialTip') + }
+ { + !providerFormSchemaPredefined && ( + + ) + }
- - + { + modelCredential && ( + + ) + }
@@ -176,6 +253,7 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav disabled={ loading || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) + || isLoading } >{t('common.operation.save')}
diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index 9056afe69..f577a536d 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -1,6 +1,5 @@ import { ValidatedStatus } from '../key-validator/declarations' import type { - CredentialFormSchemaRadio, CredentialFormSchemaTextInput, FormValue, ModelLoadBalancingConfig, @@ -82,12 +81,14 @@ export const saveCredentials = async (predefined: boolean, provider: string, v: let body, url if (predefined) { + const { __authorization_name__, ...rest } = v body = { config_from: ConfigurationMethodEnum.predefinedModel, - credentials: v, + credentials: rest, load_balancing: loadBalancing, + name: __authorization_name__, } - url = `/workspaces/current/model-providers/${provider}` + url = `/workspaces/current/model-providers/${provider}/credentials` } else { const { __model_name, __model_type, ...credentials } = v @@ -117,12 +118,17 @@ export const savePredefinedLoadBalancingConfig = async (provider: string, v: For return setModelProvider({ url, body }) } -export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue) => { +export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue, credentialId?: string) => { let url = '' let body if (predefined) { - url = `/workspaces/current/model-providers/${provider}` + url = `/workspaces/current/model-providers/${provider}/credentials` + if (credentialId) { + body = { + credential_id: credentialId, + } + } } else { if (v) { @@ -174,7 +180,7 @@ export const genModelTypeFormSchema = (modelTypes: ModelTypeEnum[]) => { show_on: [], } }), - } as CredentialFormSchemaRadio + } as any } export const genModelNameFormSchema = (model?: Pick) => { @@ -191,5 +197,5 @@ export const genModelNameFormSchema = (model?: Pick void + notAllowCustomCredential?: boolean } const Authorize = ({ pluginPayload, @@ -26,6 +29,7 @@ const Authorize = ({ canApiKey, disabled, onUpdate, + notAllowCustomCredential, }: AuthorizeProps) => { const { t } = useTranslation() const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { @@ -62,18 +66,54 @@ const Authorize = ({ } }, [canOAuth, theme, pluginPayload, t]) + const OAuthButton = useMemo(() => { + const Item = ( +
+ +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, oAuthButtonProps, disabled, onUpdate, t]) + + const ApiKeyButton = useMemo(() => { + const Item = ( +
+ +
+ ) + + if (notAllowCustomCredential) { + return ( + + {Item} + + ) + } + return Item + }, [notAllowCustomCredential, apiKeyButtonProps, disabled, onUpdate, t]) + return ( <>
{ canOAuth && ( -
- -
+ OAuthButton ) } { @@ -87,13 +127,7 @@ const Authorize = ({ } { canApiKey && ( -
- -
+ ApiKeyButton ) }
diff --git a/web/app/components/plugins/plugin-auth/authorized-in-node.tsx b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx index 79189fa58..79eef6645 100644 --- a/web/app/components/plugins/plugin-auth/authorized-in-node.tsx +++ b/web/app/components/plugins/plugin-auth/authorized-in-node.tsx @@ -35,10 +35,13 @@ const AuthorizedInNode = ({ credentials, disabled, invalidPluginCredentialInfo, + notAllowCustomCredential, } = usePluginAuth(pluginPayload, isOpen || !!credentialId) const renderTrigger = useCallback((open?: boolean) => { let label = '' let removed = false + let unavailable = false + let color = 'green' if (!credentialId) { label = t('plugin.auth.workspaceDefault') } @@ -46,6 +49,12 @@ const AuthorizedInNode = ({ const credential = credentials.find(c => c.id === credentialId) label = credential ? credential.name : t('plugin.auth.authRemoved') removed = !credential + unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise + + if (removed) + color = 'red' + else if (unavailable) + color = 'gray' } return ( ) @@ -294,18 +302,24 @@ const Authorized = ({ ) }
-
-
- -
+ { + !notAllowCustomCredential && ( + <> +
+
+ +
+ + ) + }
diff --git a/web/app/components/plugins/plugin-auth/authorized/item.tsx b/web/app/components/plugins/plugin-auth/authorized/item.tsx index 5508bcc32..f8a1033de 100644 --- a/web/app/components/plugins/plugin-auth/authorized/item.tsx +++ b/web/app/components/plugins/plugin-auth/authorized/item.tsx @@ -61,14 +61,19 @@ const Item = ({ return !(disableRename && disableEdit && disableDelete && disableSetDefault) }, [disableRename, disableEdit, disableDelete, disableSetDefault]) - return ( + const CredentialItem = (
onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)} + onClick={() => { + if (credential.not_allowed_to_use || disabled) + return + onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id) + }} > { renaming && ( @@ -121,7 +126,10 @@ const Item = ({
) } - +
) } + { + credential.from_enterprise && ( + + Enterprise + + ) + } { showAction && !renaming && (
{ - !credential.is_default && !disableSetDefault && ( + !credential.is_default && !disableSetDefault && !credential.not_allowed_to_use && ( ) @@ -93,6 +104,7 @@ const PluginAuthInAgent = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } @@ -113,6 +125,7 @@ const PluginAuthInAgent = ({ onOpenChange={setIsOpen} selectedCredentialId={credentialId || '__workspace_default__'} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } diff --git a/web/app/components/plugins/plugin-auth/plugin-auth.tsx b/web/app/components/plugins/plugin-auth/plugin-auth.tsx index 76b405a75..a9bb287cd 100644 --- a/web/app/components/plugins/plugin-auth/plugin-auth.tsx +++ b/web/app/components/plugins/plugin-auth/plugin-auth.tsx @@ -22,6 +22,7 @@ const PluginAuth = ({ credentials, disabled, invalidPluginCredentialInfo, + notAllowCustomCredential, } = usePluginAuth(pluginPayload, !!pluginPayload.provider) return ( @@ -34,6 +35,7 @@ const PluginAuth = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } @@ -46,6 +48,7 @@ const PluginAuth = ({ canApiKey={canApiKey} disabled={disabled} onUpdate={invalidPluginCredentialInfo} + notAllowCustomCredential={notAllowCustomCredential} /> ) } diff --git a/web/app/components/plugins/plugin-auth/types.ts b/web/app/components/plugins/plugin-auth/types.ts index ad41733bd..1fb2c1a53 100644 --- a/web/app/components/plugins/plugin-auth/types.ts +++ b/web/app/components/plugins/plugin-auth/types.ts @@ -22,4 +22,6 @@ export type Credential = { is_default: boolean credentials?: Record isWorkspaceDefault?: boolean + from_enterprise?: boolean + not_allowed_to_use?: boolean } diff --git a/web/context/modal-context.tsx b/web/context/modal-context.tsx index f1e5bb044..dac9ef30d 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context.tsx @@ -6,7 +6,9 @@ import { createContext, useContext, useContextSelector } from 'use-context-selec import { useRouter, useSearchParams } from 'next/navigation' import type { ConfigurationMethodEnum, + Credential, CustomConfigurationModelFixedFields, + CustomModel, ModelLoadBalancingConfigEntry, ModelProvider, } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -55,9 +57,6 @@ const ExternalAPIModal = dynamic(() => import('@/app/components/datasets/externa const ModelLoadBalancingModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'), { ssr: false, }) -const ModelLoadBalancingEntryModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal'), { - ssr: false, -}) const OpeningSettingModal = dynamic(() => import('@/app/components/base/features/new-feature-panel/conversation-opener/modal'), { ssr: false, }) @@ -84,6 +83,9 @@ export type ModelModalType = { currentProvider: ModelProvider currentConfigurationMethod: ConfigurationMethodEnum currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields + isModelCredential?: boolean + credential?: Credential + model?: CustomModel } export type LoadBalancingEntryModalType = ModelModalType & { entry?: ModelLoadBalancingConfigEntry @@ -100,7 +102,6 @@ export type ModalContextState = { setShowModelModal: Dispatch | null>> setShowExternalKnowledgeAPIModal: Dispatch | null>> setShowModelLoadBalancingModal: Dispatch> - setShowModelLoadBalancingEntryModal: Dispatch | null>> setShowOpeningModal: Dispatch({ setShowModelModal: noop, setShowExternalKnowledgeAPIModal: noop, setShowModelLoadBalancingModal: noop, - setShowModelLoadBalancingEntryModal: noop, setShowOpeningModal: noop, setShowUpdatePluginModal: noop, setShowEducationExpireNoticeModal: noop, @@ -145,7 +145,6 @@ export const ModalContextProvider = ({ const [showModelModal, setShowModelModal] = useState | null>(null) const [showExternalKnowledgeAPIModal, setShowExternalKnowledgeAPIModal] = useState | null>(null) const [showModelLoadBalancingModal, setShowModelLoadBalancingModal] = useState(null) - const [showModelLoadBalancingEntryModal, setShowModelLoadBalancingEntryModal] = useState | null>(null) const [showOpeningModal, setShowOpeningModal] = useState { - showModelLoadBalancingEntryModal?.onCancelCallback?.() - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - const handleCancelOpeningModal = useCallback(() => { setShowOpeningModal(null) if (showOpeningModal?.onCancelCallback) showOpeningModal.onCancelCallback() }, [showOpeningModal]) - const handleSaveModelLoadBalancingEntryModal = useCallback((entry: ModelLoadBalancingConfigEntry) => { - showModelLoadBalancingEntryModal?.onSaveCallback?.({ - ...showModelLoadBalancingEntryModal.payload, - entry, - }) - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - - const handleRemoveModelLoadBalancingEntry = useCallback(() => { - showModelLoadBalancingEntryModal?.onRemoveCallback?.(showModelLoadBalancingEntryModal.payload) - setShowModelLoadBalancingEntryModal(null) - }, [showModelLoadBalancingEntryModal]) - const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => { if (showApiBasedExtensionModal?.onSaveCallback) showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension) @@ -277,7 +258,6 @@ export const ModalContextProvider = ({ setShowModelModal, setShowExternalKnowledgeAPIModal, setShowModelLoadBalancingModal, - setShowModelLoadBalancingEntryModal, setShowOpeningModal, setShowUpdatePluginModal, setShowEducationExpireNoticeModal, @@ -346,6 +326,9 @@ export const ModalContextProvider = ({ provider={showModelModal.payload.currentProvider} configurateMethod={showModelModal.payload.currentConfigurationMethod} currentCustomConfigurationModelFixedFields={showModelModal.payload.currentCustomConfigurationModelFixedFields} + isModelCredential={showModelModal.payload.isModelCredential} + credential={showModelModal.payload.credential} + model={showModelModal.payload.model} onCancel={handleCancelModelModal} onSave={handleSaveModelModal} /> @@ -368,19 +351,6 @@ export const ModalContextProvider = ({ ) } - { - !!showModelLoadBalancingEntryModal && ( - - ) - } {showOpeningModal && ( { queryFn: () => get<{ data: ModelItem[] }>(`/workspaces/current/model-providers/${provider}/models`), }) } + +export const useGetProviderCredential = (enabled: boolean, provider: string, credentialId?: string) => { + return useQuery({ + enabled, + queryKey: [NAME_SPACE, 'model-list', provider, credentialId], + queryFn: () => get(`/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`), + }) +} + +export const useAddProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ProviderCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useEditProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ProviderCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useDeleteProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { + body: data, + }), + }) +} + +export const useActiveProviderCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials/switch`, { + body: data, + }), + }) +} + +export const useGetModelCredential = ( + enabled: boolean, + provider: string, + credentialId?: string, + model?: string, + modelType?: string, + configFrom?: string, +) => { + return useQuery({ + enabled, + queryKey: [NAME_SPACE, 'model-list', provider, model, modelType, credentialId], + queryFn: () => get(`/workspaces/current/model-providers/${provider}/models/credentials?model=${model}&model_type=${modelType}&config_from=${configFrom}${credentialId ? `&credential_id=${credentialId}` : ''}`), + staleTime: 0, + gcTime: 0, + }) +} + +export const useAddModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ModelCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useEditModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: ModelCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useDeleteModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useDeleteModel = (provider: string) => { + return useMutation({ + mutationFn: (data: { + model: string + model_type: ModelTypeEnum + }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { + body: data, + }), + }) +} + +export const useActiveModelCredential = (provider: string) => { + return useMutation({ + mutationFn: (data: { + credential_id: string + model?: string + model_type?: ModelTypeEnum + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials/switch`, { + body: data, + }), + }) +} + +export const useUpdateModelLoadBalancingConfig = (provider: string) => { + return useMutation({ + mutationFn: (data: { + config_from: string + model: string + model_type: ModelTypeEnum + load_balancing: ModelLoadBalancingConfig + credential_id?: string + }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, { + body: data, + }), + }) +} diff --git a/web/service/use-plugins-auth.ts b/web/service/use-plugins-auth.ts index 2dc026064..51992361e 100644 --- a/web/service/use-plugins-auth.ts +++ b/web/service/use-plugins-auth.ts @@ -19,6 +19,7 @@ export const useGetPluginCredentialInfo = ( enabled: !!url, queryKey: [NAME_SPACE, 'credential-info', url], queryFn: () => get<{ + allow_custom_token?: boolean supported_credential_types: string[] credentials: Credential[] is_oauth_custom_client_enabled: boolean