feat: add multi model credentials (#24451)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
@@ -60,9 +62,9 @@ class Provider(Base):
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
@@ -79,6 +81,21 @@ class Provider(Base):
|
||||
f" provider_type='{self.provider_type}')>"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
credential = self.credential
|
||||
return credential.credential_name if credential else None
|
||||
|
||||
@property
|
||||
def encrypted_config(self):
|
||||
credential = self.credential
|
||||
return credential.encrypted_config if credential else None
|
||||
|
||||
@property
|
||||
def token_is_set(self):
|
||||
"""
|
||||
@@ -116,11 +133,30 @@ class ProviderModel(Base):
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
credential = self.credential
|
||||
return credential.credential_name if credential else None
|
||||
|
||||
@property
|
||||
def encrypted_config(self):
|
||||
credential = self.credential
|
||||
return credential.encrypted_config if credential else None
|
||||
|
||||
|
||||
class TenantDefaultModel(Base):
|
||||
__tablename__ = "tenant_default_models"
|
||||
@@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base):
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
"""
|
||||
Provider credential - stores multiple named credentials for each provider
|
||||
"""
|
||||
|
||||
__tablename__ = "provider_credentials"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"),
|
||||
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
"""
|
||||
Provider model credential - stores multiple named credentials for each provider model
|
||||
"""
|
||||
|
||||
__tablename__ = "provider_model_credentials"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"),
|
||||
sa.Index(
|
||||
"provider_model_credential_tenant_provider_model_idx",
|
||||
"tenant_id",
|
||||
"provider_name",
|
||||
"model_name",
|
||||
"model_type",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
Reference in New Issue
Block a user