feat: add multi model credentials (#24451)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
非法操作
2025-08-25 16:12:29 +08:00
committed by GitHub
parent b08bfa203a
commit 6010d5f24c
65 changed files with 5202 additions and 1814 deletions

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from enum import Enum
from functools import cached_property
from typing import Optional
import sqlalchemy as sa
@@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
@@ -60,9 +62,9 @@ class Provider(Base):
provider_type: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'custom'::character varying")
)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
quota_type: Mapped[Optional[str]] = mapped_column(
String(40), nullable=True, server_default=text("''::character varying")
@@ -79,6 +81,21 @@ class Provider(Base):
f" provider_type='{self.provider_type}')>"
)
@cached_property
def credential(self):
if self.credential_id:
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
@property
def credential_name(self):
credential = self.credential
return credential.credential_name if credential else None
@property
def encrypted_config(self):
credential = self.credential
return credential.encrypted_config if credential else None
@property
def token_is_set(self):
"""
@@ -116,11 +133,30 @@ class ProviderModel(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@cached_property
def credential(self):
if self.credential_id:
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
)
@property
def credential_name(self):
credential = self.credential
return credential.credential_name if credential else None
@property
def encrypted_config(self):
credential = self.credential
return credential.encrypted_config if credential else None
class TenantDefaultModel(Base):
__tablename__ = "tenant_default_models"
@@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base):
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderCredential(Base):
"""
Provider credential - stores multiple named credentials for each provider
"""
__tablename__ = "provider_credentials"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"),
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelCredential(Base):
"""
Provider model credential - stores multiple named credentials for each provider model
"""
__tablename__ = "provider_model_credentials"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"),
sa.Index(
"provider_model_credential_tenant_provider_model_idx",
"tenant_id",
"provider_name",
"model_name",
"model_type",
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())