diff --git a/api/commands.py b/api/commands.py index 79bb6713d..8177f1a48 100644 --- a/api/commands.py +++ b/api/commands.py @@ -5,6 +5,7 @@ import secrets from typing import Any, Optional import click +import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select @@ -457,7 +458,7 @@ def convert_to_agent_apps(): """ with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query)) + rs = conn.execute(sa.text(sql_query)) apps = [] for i in rs: @@ -702,7 +703,7 @@ def fix_app_site_missing(): sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id where sites.id is null limit 1000""" with db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) processed_count = 0 for i in rs: @@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool): ) orphaned_message_files = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) @@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) @@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) @@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool): f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) elif ids_table["type"] == "text": @@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) + conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) except Exception as e: click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) return @@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append(str(i[0])) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 32b64d10c..343b7acd7 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -67,7 +67,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "message_count": i.message_count}) @@ -176,7 +176,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -234,7 +234,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} @@ -310,7 +310,7 @@ ORDER BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} @@ -373,7 +373,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -435,7 +435,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) @@ -495,7 +495,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707..7f80afd83 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -71,7 +72,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "runs": i.runs}) @@ -133,7 +134,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -195,7 +196,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -277,7 +278,7 @@ GROUP BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 1bb4cfa4c..2737bcfb1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -7,6 +7,7 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +import sqlalchemy as sa from pydantic import TypeAdapter from yarl import URL @@ -616,7 +617,7 @@ class ToolManager: WHERE tenant_id = :tenant_id ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ - ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod diff --git a/api/models/account.py b/api/models/account.py index 343705589..1a0752440 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,6 +3,7 @@ import json from datetime import datetime from typing import Optional, cast +import sqlalchemy as sa from flask_login import UserMixin # type: ignore from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor @@ -83,9 +84,9 @@ class AccountStatus(enum.StrEnum): class Account(UserMixin, Base): __tablename__ = "accounts" - __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) email: Mapped[str] = mapped_column(String(255)) password: Mapped[Optional[str]] = mapped_column(String(255)) @@ -97,7 +98,7 @@ class Account(UserMixin, Base): last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) - status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying")) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @@ -195,14 +196,14 @@ class TenantStatus(enum.StrEnum): class Tenant(Base): __tablename__ = "tenants" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255)) - encrypt_public_key = db.Column(db.Text) - plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) - custom_config: Mapped[Optional[str]] = mapped_column(db.Text) + encrypt_public_key = db.Column(sa.Text) + plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + custom_config: Mapped[Optional[str]] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -225,16 +226,16 @@ class Tenant(Base): class TenantAccountJoin(Base): __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), - db.Index("tenant_account_join_account_id_idx", "account_id"), - db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), - db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), + sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + sa.Index("tenant_account_join_account_id_idx", "account_id"), + sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) role: Mapped[str] = mapped_column(String(16), server_default="normal") invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) @@ -244,12 +245,12 @@ class TenantAccountJoin(Base): class AccountIntegrate(Base): __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), - db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), - db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), + sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) account_id: Mapped[str] = mapped_column(StringUUID) provider: Mapped[str] = mapped_column(String(16)) open_id: Mapped[str] = mapped_column(String(255)) @@ -261,20 +262,20 @@ class AccountIntegrate(Base): class InvitationCode(Base): __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), - db.Index("invitation_codes_batch_idx", "batch"), - db.Index("invitation_codes_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + sa.Index("invitation_codes_batch_idx", "batch"), + sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(db.Integer) + id: Mapped[int] = mapped_column(sa.Integer) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying")) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -290,11 +291,11 @@ class TenantPluginPermission(Base): __tablename__ = "account_plugin_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), + sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") @@ -313,16 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base): __tablename__ = "tenant_plugin_auto_upgrade_strategies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), + sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index ac9eda682..60167d906 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,11 +1,11 @@ import enum from datetime import datetime +import sqlalchemy as sa from sqlalchemy import DateTime, String, Text, func from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -19,11 +19,11 @@ class APIBasedExtensionPoint(enum.Enum): class APIBasedExtension(Base): __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), - db.Index("api_based_extension_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + sa.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/api/models/dataset.py b/api/models/dataset.py index e62101ae7..3b1d289bc 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,6 +12,7 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast +import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -38,23 +39,23 @@ class DatasetPermissionEnum(enum.StrEnum): class Dataset(Base): __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_pkey"), - db.Index("dataset_tenant_idx", "tenant_id"), - db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="dataset_pkey"), + sa.Index("dataset_tenant_idx", "tenant_id"), + sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) - description = mapped_column(db.Text, nullable=True) - provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying")) - permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying")) + description = mapped_column(sa.Text, nullable=True) + provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying")) data_source_type = mapped_column(String(255)) indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) - index_struct = mapped_column(db.Text, nullable=True) + index_struct = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -63,7 +64,7 @@ class Dataset(Base): embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def dataset_keyword_table(self): @@ -262,14 +263,14 @@ class Dataset(Base): class DatasetProcessRule(Base): __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), - db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - rules = mapped_column(db.Text, nullable=True) + mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) + rules = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -302,20 +303,20 @@ class DatasetProcessRule(Base): class Document(Base): __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_pkey"), - db.Index("document_dataset_id_idx", "dataset_id"), - db.Index("document_is_paused_idx", "is_paused"), - db.Index("document_tenant_idx", "tenant_id"), - db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="document_pkey"), + sa.Index("document_dataset_id_idx", "dataset_id"), + sa.Index("document_is_paused_idx", "is_paused"), + sa.Index("document_tenant_idx", "tenant_id"), + sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - position: Mapped[int] = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) - data_source_info = mapped_column(db.Text, nullable=True) + data_source_info = mapped_column(sa.Text, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) batch: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) @@ -328,8 +329,8 @@ class Document(Base): processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # parsing - file_id = mapped_column(db.Text, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) # TODO: make this not nullable + file_id = mapped_column(sa.Text, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # cleaning @@ -339,32 +340,32 @@ class Document(Base): splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # indexing - tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # pause - is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) paused_by = mapped_column(StringUUID, nullable=True) paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # error - error = mapped_column(db.Text, nullable=True) + error = mapped_column(sa.Text, nullable=True) stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying")) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying")) doc_language = mapped_column(String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -643,44 +644,44 @@ class Document(Base): class DocumentSegment(Base): __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_segment_pkey"), - db.Index("document_segment_dataset_id_idx", "dataset_id"), - db.Index("document_segment_document_id_idx", "document_id"), - db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), - db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), - db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), - db.Index("document_segment_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="document_segment_pkey"), + sa.Index("document_segment_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_document_id_idx", "document_id"), + sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), + sa.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] - content = mapped_column(db.Text, nullable=False) - answer = mapped_column(db.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + answer = mapped_column(sa.Text, nullable=True) word_count: Mapped[int] tokens: Mapped[int] # indexing fields - keywords = mapped_column(db.JSON, nullable=True) + keywords = mapped_column(sa.JSON, nullable=True) index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) # basic fields - hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) + error = mapped_column(sa.Text, nullable=True) stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @property @@ -794,36 +795,36 @@ class DocumentSegment(Base): class ChildChunk(Base): __tablename__ = "child_chunks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), - db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), - db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), - db.Index("child_chunks_segment_idx", "segment_id"), + sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), + sa.Index("child_chunks_segment_idx", "segment_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) segment_id = mapped_column(StringUUID, nullable=False) - position: Mapped[int] = mapped_column(db.Integer, nullable=False) - content = mapped_column(db.Text, nullable=False) - word_count: Mapped[int] = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + content = mapped_column(sa.Text, nullable=False) + word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_by = mapped_column(StringUUID, nullable=True) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) + error = mapped_column(sa.Text, nullable=True) @property def dataset(self): @@ -841,11 +842,11 @@ class ChildChunk(Base): class AppDatasetJoin(Base): __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), - db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -858,13 +859,13 @@ class AppDatasetJoin(Base): class DatasetQuery(Base): __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), - db.Index("dataset_query_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + sa.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - content = mapped_column(db.Text, nullable=False) + content = mapped_column(sa.Text, nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id = mapped_column(StringUUID, nullable=True) created_by_role = mapped_column(String, nullable=False) @@ -875,15 +876,15 @@ class DatasetQuery(Base): class DatasetKeywordTable(Base): __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), - db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False, unique=True) - keyword_table = mapped_column(db.Text, nullable=False) + keyword_table = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column( - String(255), nullable=False, server_default=db.text("'database'::character varying") + String(255), nullable=False, server_default=sa.text("'database'::character varying") ) @property @@ -920,19 +921,19 @@ class DatasetKeywordTable(Base): class Embedding(Base): __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="embedding_pkey"), - db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), - db.Index("created_at_idx", "created_at"), + sa.PrimaryKeyConstraint("id", name="embedding_pkey"), + sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + sa.Index("created_at_idx", "created_at"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) model_name = mapped_column( - String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying") ) hash = mapped_column(String(64), nullable=False) - embedding = mapped_column(db.LargeBinary, nullable=False) + embedding = mapped_column(sa.LargeBinary, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying")) + provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -944,14 +945,14 @@ class Embedding(Base): class DatasetCollectionBinding(Base): __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), - db.Index("provider_model_name_idx", "provider_name", "model_name"), + sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + sa.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False) collection_name = mapped_column(String(64), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -959,17 +960,17 @@ class DatasetCollectionBinding(Base): class TidbAuthBinding(Base): __tablename__ = "tidb_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), - db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), - db.Index("tidb_auth_bindings_active_idx", "active"), - db.Index("tidb_auth_bindings_created_at_idx", "created_at"), - db.Index("tidb_auth_bindings_status_idx", "status"), + sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + sa.Index("tidb_auth_bindings_active_idx", "active"), + sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), + sa.Index("tidb_auth_bindings_status_idx", "status"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) - active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying")) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) @@ -979,10 +980,10 @@ class TidbAuthBinding(Base): class Whitelist(Base): __tablename__ = "whitelists" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="whitelists_pkey"), - db.Index("whitelists_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="whitelists_pkey"), + sa.Index("whitelists_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) category: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -991,33 +992,33 @@ class Whitelist(Base): class DatasetPermission(Base): __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), - db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), - db.Index("idx_dataset_permissions_account_id", "account_id"), - db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + sa.Index("idx_dataset_permissions_account_id", "account_id"), + sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True) dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): __tablename__ = "external_knowledge_apis" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), - db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), - db.Index("external_knowledge_apis_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), + sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_apis_name_idx", "name"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - settings = mapped_column(db.Text, nullable=True) + settings = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1061,18 +1062,18 @@ class ExternalKnowledgeApis(Base): class ExternalKnowledgeBindings(Base): __tablename__ = "external_knowledge_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), - db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), - db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), - db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), - db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) external_knowledge_api_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - external_knowledge_id = mapped_column(db.Text, nullable=False) + external_knowledge_id = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1082,57 +1083,57 @@ class ExternalKnowledgeBindings(Base): class DatasetAutoDisableLog(Base): __tablename__ = "dataset_auto_disable_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), - db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), - db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), - db.Index("dataset_auto_disable_log_created_atx", "created_at"), + sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + sa.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) created_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) class RateLimitLog(Base): __tablename__ = "rate_limit_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), - db.Index("rate_limit_log_tenant_idx", "tenant_id"), - db.Index("rate_limit_log_operation_idx", "operation"), + sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), + sa.Index("rate_limit_log_tenant_idx", "tenant_id"), + sa.Index("rate_limit_log_operation_idx", "operation"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) operation: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) class DatasetMetadata(Base): __tablename__ = "dataset_metadatas" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), - db.Index("dataset_metadata_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_dataset_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), + sa.Index("dataset_metadata_tenant_idx", "tenant_id"), + sa.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1141,14 +1142,14 @@ class DatasetMetadata(Base): class DatasetMetadataBinding(Base): __tablename__ = "dataset_metadata_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), - db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), - db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), - db.Index("dataset_metadata_binding_document_idx", "document_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), + sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), + sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), + sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), + sa.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index fba0d692e..c4303f3cc 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -35,10 +35,10 @@ from .types import StringUUID class DifySetup(Base): __tablename__ = "dify_setups" - __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version: Mapped[str] = mapped_column(String(255), nullable=False) - setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -69,33 +69,33 @@ class IconType(Enum): class App(Base): __tablename__ = "apps" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) - description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) + description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) mode: Mapped[str] = mapped_column(String(255)) icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji icon = db.Column(String(255)) icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying")) - enable_site: Mapped[bool] = mapped_column(db.Boolean) - enable_api: Mapped[bool] = mapped_column(db.Boolean) - api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - tracing = mapped_column(db.Text, nullable=True) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + enable_site: Mapped[bool] = mapped_column(sa.Boolean) + enable_api: Mapped[bool] = mapped_column(sa.Boolean) + api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + tracing = mapped_column(sa.Text, nullable=True) max_active_requests: Mapped[Optional[int]] created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def desc_or_prompt(self): @@ -302,36 +302,36 @@ class App(Base): class AppModelConfig(Base): __tablename__ = "app_model_configs" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) provider = mapped_column(String(255), nullable=True) model_id = mapped_column(String(255), nullable=True) - configs = mapped_column(db.JSON, nullable=True) + configs = mapped_column(sa.JSON, nullable=True) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - opening_statement = mapped_column(db.Text) - suggested_questions = mapped_column(db.Text) - suggested_questions_after_answer = mapped_column(db.Text) - speech_to_text = mapped_column(db.Text) - text_to_speech = mapped_column(db.Text) - more_like_this = mapped_column(db.Text) - model = mapped_column(db.Text) - user_input_form = mapped_column(db.Text) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = mapped_column(sa.Text) + suggested_questions = mapped_column(sa.Text) + suggested_questions_after_answer = mapped_column(sa.Text) + speech_to_text = mapped_column(sa.Text) + text_to_speech = mapped_column(sa.Text) + more_like_this = mapped_column(sa.Text) + model = mapped_column(sa.Text) + user_input_form = mapped_column(sa.Text) dataset_query_variable = mapped_column(String(255)) - pre_prompt = mapped_column(db.Text) - agent_mode = mapped_column(db.Text) - sensitive_word_avoidance = mapped_column(db.Text) - retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying")) - chat_prompt_config = mapped_column(db.Text) - completion_prompt_config = mapped_column(db.Text) - dataset_configs = mapped_column(db.Text) - external_data_tools = mapped_column(db.Text) - file_upload = mapped_column(db.Text) + pre_prompt = mapped_column(sa.Text) + agent_mode = mapped_column(sa.Text) + sensitive_word_avoidance = mapped_column(sa.Text) + retriever_resource = mapped_column(sa.Text) + prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying")) + chat_prompt_config = mapped_column(sa.Text) + completion_prompt_config = mapped_column(sa.Text) + dataset_configs = mapped_column(sa.Text) + external_data_tools = mapped_column(sa.Text) + file_upload = mapped_column(sa.Text) @property def app(self): @@ -553,24 +553,24 @@ class AppModelConfig(Base): class RecommendedApp(Base): __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), - db.Index("recommended_app_app_id_idx", "app_id"), - db.Index("recommended_app_is_listed_idx", "is_listed", "language"), + sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + sa.Index("recommended_app_app_id_idx", "app_id"), + sa.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - description = mapped_column(db.JSON, nullable=False) + description = mapped_column(sa.JSON, nullable=False) copyright: Mapped[str] = mapped_column(String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") category: Mapped[str] = mapped_column(String(255), nullable=False) - position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) - is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True) - install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -581,20 +581,20 @@ class RecommendedApp(Base): class InstalledApp(Base): __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="installed_app_pkey"), - db.Index("installed_app_tenant_id_idx", "tenant_id"), - db.Index("installed_app_app_id_idx", "app_id"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), + sa.PrimaryKeyConstraint("id", name="installed_app_pkey"), + sa.Index("installed_app_tenant_id_idx", "tenant_id"), + sa.Index("installed_app_app_id_idx", "app_id"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) - is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -610,23 +610,23 @@ class InstalledApp(Base): class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="conversation_pkey"), - db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + sa.PrimaryKeyConstraint("id", name="conversation_pkey"), + sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) model_provider = mapped_column(String(255), nullable=True) - override_model_configs = mapped_column(db.Text) + override_model_configs = mapped_column(sa.Text) model_id = mapped_column(String(255), nullable=True) mode: Mapped[str] = mapped_column(String(255)) name: Mapped[str] = mapped_column(String(255), nullable=False) - summary = mapped_column(db.Text) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - introduction = mapped_column(db.Text) - system_instruction = mapped_column(db.Text) - system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + summary = mapped_column(sa.Text) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + introduction = mapped_column(sa.Text) + system_instruction = mapped_column(sa.Text) + system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. @@ -639,18 +639,18 @@ class Conversation(Base): from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - read_at = mapped_column(db.DateTime) + read_at = mapped_column(sa.DateTime) read_account_id = mapped_column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def inputs(self): @@ -892,36 +892,36 @@ class Message(Base): Index("message_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) model_provider = mapped_column(String(255), nullable=True) model_id = mapped_column(String(255), nullable=True) - override_model_configs = mapped_column(db.Text) - conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - query: Mapped[str] = mapped_column(db.Text, nullable=False) - message = mapped_column(db.JSON, nullable=False) - message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column - answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + override_model_configs = mapped_column(sa.Text) + conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + query: Mapped[str] = mapped_column(sa.Text, nullable=False) + message = mapped_column(sa.JSON, nullable=False) + message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) parent_message_id = mapped_column(StringUUID, nullable=True) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) - status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) - error = mapped_column(db.Text) - message_metadata = mapped_column(db.Text) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + error = mapped_column(sa.Text) + message_metadata = mapped_column(sa.Text) invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1228,23 +1228,23 @@ class Message(Base): class MessageFeedback(Base): __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), - db.Index("message_feedback_app_idx", "app_id"), - db.Index("message_feedback_message_idx", "message_id", "from_source"), - db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), + sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + sa.Index("message_feedback_app_idx", "app_id"), + sa.Index("message_feedback_message_idx", "message_id", "from_source"), + sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) rating: Mapped[str] = mapped_column(String(255), nullable=False) - content = mapped_column(db.Text) + content = mapped_column(sa.Text) from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1270,9 +1270,9 @@ class MessageFeedback(Base): class MessageFile(Base): __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_file_pkey"), - db.Index("message_file_message_idx", "message_id"), - db.Index("message_file_created_by_idx", "created_by"), + sa.PrimaryKeyConstraint("id", name="message_file_pkey"), + sa.Index("message_file_message_idx", "message_id"), + sa.Index("message_file_created_by_idx", "created_by"), ) def __init__( @@ -1296,37 +1296,37 @@ class MessageFile(Base): self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(Base): __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), - db.Index("message_annotation_app_idx", "app_id"), - db.Index("message_annotation_conversation_idx", "conversation_id"), - db.Index("message_annotation_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + sa.Index("message_annotation_app_idx", "app_id"), + sa.Index("message_annotation_conversation_idx", "conversation_id"), + sa.Index("message_annotation_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id")) + conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[Optional[str]] = mapped_column(StringUUID) - question = db.Column(db.Text, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + question = db.Column(sa.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1342,24 +1342,24 @@ class MessageAnnotation(Base): class AppAnnotationHitHistory(Base): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), - db.Index("app_annotation_hit_histories_app_idx", "app_id"), - db.Index("app_annotation_hit_histories_account_idx", "account_id"), - db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), - db.Index("app_annotation_hit_histories_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + sa.Index("app_annotation_hit_histories_app_idx", "app_id"), + sa.Index("app_annotation_hit_histories_account_idx", "account_id"), + sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - source = mapped_column(db.Text, nullable=False) - question = mapped_column(db.Text, nullable=False) + source = mapped_column(sa.Text, nullable=False) + question = mapped_column(sa.Text, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - score = mapped_column(Float, nullable=False, server_default=db.text("0")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + score = mapped_column(Float, nullable=False, server_default=sa.text("0")) message_id = mapped_column(StringUUID, nullable=False) - annotation_question = mapped_column(db.Text, nullable=False) - annotation_content = mapped_column(db.Text, nullable=False) + annotation_question = mapped_column(sa.Text, nullable=False) + annotation_content = mapped_column(sa.Text, nullable=False) @property def account(self): @@ -1380,18 +1380,18 @@ class AppAnnotationHitHistory(Base): class AppAnnotationSetting(Base): __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), - db.Index("app_annotation_settings_app_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + sa.Index("app_annotation_settings_app_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0")) + score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0")) collection_binding_id = mapped_column(StringUUID, nullable=False) created_user_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = mapped_column(StringUUID, nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def collection_binding_detail(self): @@ -1408,58 +1408,58 @@ class AppAnnotationSetting(Base): class OperationLog(Base): __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="operation_log_pkey"), - db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), + sa.PrimaryKeyConstraint("id", name="operation_log_pkey"), + sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content = mapped_column(db.JSON) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + content = mapped_column(sa.JSON) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_ip: Mapped[str] = mapped_column(String(255), nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="end_user_pkey"), - db.Index("end_user_session_id_idx", "session_id", "type"), - db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), + sa.PrimaryKeyConstraint("id", name="end_user_pkey"), + sa.Index("end_user_session_id_idx", "session_id", "type"), + sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) type: Mapped[str] = mapped_column(String(255), nullable=False) external_user_id = mapped_column(String(255), nullable=True) name = mapped_column(String(255)) - is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) session_id: Mapped[str] = mapped_column() - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMCPServer(Base): __tablename__ = "app_mcp_servers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), - db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), + sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), + sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=False) server_code: Mapped[str] = mapped_column(String(255), nullable=False) - status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) - parameters = mapped_column(db.Text, nullable=False) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + parameters = mapped_column(sa.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_server_code(n): @@ -1478,34 +1478,34 @@ class AppMCPServer(Base): class Site(Base): __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="site_pkey"), - db.Index("site_app_id_idx", "app_id"), - db.Index("site_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="site_pkey"), + sa.Index("site_app_id_idx", "app_id"), + sa.Index("site_code_idx", "code", "status"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) icon_type = mapped_column(String(255), nullable=True) icon = mapped_column(String(255)) icon_background = mapped_column(String(255)) - description = mapped_column(db.Text) + description = mapped_column(sa.Text) default_language: Mapped[str] = mapped_column(String(255), nullable=False) chat_color_theme = mapped_column(String(255)) - chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) copyright = mapped_column(String(255)) privacy_policy = mapped_column(String(255)) - show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") customize_domain = mapped_column(String(255)) customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) - prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying")) + prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) code = mapped_column(String(255)) @property @@ -1535,19 +1535,19 @@ class Site(Base): class ApiToken(Base): __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_token_pkey"), - db.Index("api_token_app_id_type_idx", "app_id", "type"), - db.Index("api_token_token_idx", "token", "type"), - db.Index("api_token_tenant_idx", "tenant_id", "type"), + sa.PrimaryKeyConstraint("id", name="api_token_pkey"), + sa.Index("api_token_app_id_type_idx", "app_id", "type"), + sa.Index("api_token_token_idx", "token", "type"), + sa.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) type = mapped_column(String(16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1561,26 +1561,26 @@ class ApiToken(Base): class UploadFile(Base): __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="upload_file_pkey"), - db.Index("upload_file_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), + sa.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) storage_type: Mapped[str] = mapped_column(String(255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - size: Mapped[int] = mapped_column(db.Integer, nullable=False) + size: Mapped[int] = mapped_column(sa.Integer, nullable=False) extension: Mapped[str] = mapped_column(String(255), nullable=False) mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=sa.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") @@ -1623,71 +1623,71 @@ class UploadFile(Base): class ApiRequest(Base): __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_request_pkey"), - db.Index("api_request_token_idx", "tenant_id", "api_token_id"), + sa.PrimaryKeyConstraint("id", name="api_request_pkey"), + sa.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) path: Mapped[str] = mapped_column(String(255), nullable=False) - request = mapped_column(db.Text, nullable=True) - response = mapped_column(db.Text, nullable=True) + request = mapped_column(sa.Text, nullable=True) + response = mapped_column(sa.Text, nullable=True) ip: Mapped[str] = mapped_column(String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_chain_pkey"), - db.Index("message_chain_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_chain_pkey"), + sa.Index("message_chain_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - input = mapped_column(db.Text, nullable=True) - output = mapped_column(db.Text, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + input = mapped_column(sa.Text, nullable=True) + output = mapped_column(sa.Text, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class MessageAgentThought(Base): __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), - db.Index("message_agent_thought_message_id_idx", "message_id"), - db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), + sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + sa.Index("message_agent_thought_message_id_idx", "message_id"), + sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = mapped_column(StringUUID, nullable=True) - position: Mapped[int] = mapped_column(db.Integer, nullable=False) - thought = mapped_column(db.Text, nullable=True) - tool = mapped_column(db.Text, nullable=True) - tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_input = mapped_column(db.Text, nullable=True) - observation = mapped_column(db.Text, nullable=True) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + thought = mapped_column(sa.Text, nullable=True) + tool = mapped_column(sa.Text, nullable=True) + tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_input = mapped_column(sa.Text, nullable=True) + observation = mapped_column(sa.Text, nullable=True) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(db.Text, nullable=True) - message = mapped_column(db.Text, nullable=True) - message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - message_unit_price = mapped_column(db.Numeric, nullable=True) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - message_files = mapped_column(db.Text, nullable=True) - answer = db.Column(db.Text, nullable=True) - answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - answer_unit_price = mapped_column(db.Numeric, nullable=True) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - total_price = mapped_column(db.Numeric, nullable=True) + tool_process_data = mapped_column(sa.Text, nullable=True) + message = mapped_column(sa.Text, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_unit_price = mapped_column(sa.Numeric, nullable=True) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + message_files = mapped_column(sa.Text, nullable=True) + answer = db.Column(sa.Text, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_unit_price = mapped_column(sa.Numeric, nullable=True) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + total_price = mapped_column(sa.Numeric, nullable=True) currency = mapped_column(String, nullable=True) - latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) + latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def files(self) -> list: @@ -1769,80 +1769,80 @@ class MessageAgentThought(Base): class DatasetRetrieverResource(Base): __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), - db.Index("dataset_retriever_resource_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + sa.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - position: Mapped[int] = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - dataset_name = mapped_column(db.Text, nullable=False) + dataset_name = mapped_column(sa.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) - document_name = mapped_column(db.Text, nullable=False) - data_source_type = mapped_column(db.Text, nullable=True) + document_name = mapped_column(sa.Text, nullable=False) + data_source_type = mapped_column(sa.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) - index_node_hash = mapped_column(db.Text, nullable=True) - retriever_from = mapped_column(db.Text, nullable=False) + score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + index_node_hash = mapped_column(sa.Text, nullable=True) + retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class Tag(Base): __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_pkey"), - db.Index("tag_type_idx", "type"), - db.Index("tag_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="tag_pkey"), + sa.Index("tag_type_idx", "type"), + sa.Index("tag_name_idx", "name"), ) TAG_TYPE_LIST = ["knowledge", "app"] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) type = mapped_column(String(16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), - db.Index("tag_bind_target_id_idx", "target_id"), - db.Index("tag_bind_tag_id_idx", "tag_id"), + sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + sa.Index("tag_bind_target_id_idx", "target_id"), + sa.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) tag_id = mapped_column(StringUUID, nullable=True) target_id = mapped_column(StringUUID, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), - db.Index("trace_app_config_app_id_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + sa.Index("trace_app_config_app_id_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) tracing_provider = mapped_column(String(255), nullable=True) - tracing_config = mapped_column(db.JSON, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + tracing_config = mapped_column(sa.JSON, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 7bfc249b0..4ea2c59fd 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,11 +2,11 @@ from datetime import datetime from enum import Enum from typing import Optional +import sqlalchemy as sa from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -47,9 +47,9 @@ class Provider(Base): __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_pkey"), - db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_pkey"), + sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" ), ) @@ -60,15 +60,15 @@ class Provider(Base): provider_type: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) + quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -104,9 +104,9 @@ class ProviderModel(Base): __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_pkey"), - db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_model_pkey"), + sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" ), ) @@ -116,8 +116,8 @@ class ProviderModel(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -125,8 +125,8 @@ class ProviderModel(Base): class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), - db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) @@ -141,8 +141,8 @@ class TenantDefaultModel(Base): class TenantPreferredModelProvider(Base): __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), - db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) @@ -156,8 +156,8 @@ class TenantPreferredModelProvider(Base): class ProviderOrder(Base): __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_order_pkey"), - db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="provider_order_pkey"), + sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) @@ -167,9 +167,9 @@ class ProviderOrder(Base): payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) payment_id: Mapped[Optional[str]] = mapped_column(String(191)) transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) - quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) + quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[Optional[str]] = mapped_column(String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) @@ -187,8 +187,8 @@ class ProviderModelSetting(Base): __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), - db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) @@ -196,8 +196,8 @@ class ProviderModelSetting(Base): provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -209,8 +209,8 @@ class LoadBalancingModelConfig(Base): __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), - db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) @@ -219,7 +219,7 @@ class LoadBalancingModelConfig(Base): model_name: Mapped[str] = mapped_column(String(255), nullable=False) model_type: Mapped[str] = mapped_column(String(40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 8191c874a..8456d65a8 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -2,50 +2,50 @@ import json from datetime import datetime from typing import Optional +import sqlalchemy as sa from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column from models.base import Base -from .engine import db from .types import StringUUID class DataSourceOauthBinding(Base): __tablename__ = "data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="source_binding_pkey"), - db.Index("source_binding_tenant_id_idx", "tenant_id"), - db.Index("source_info_idx", "source_info", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="source_binding_pkey"), + sa.Index("source_binding_tenant_id_idx", "tenant_id"), + sa.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) access_token: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), - db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), - db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), + sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) category: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - credentials = mapped_column(db.Text, nullable=True) # JSON + credentials = mapped_column(sa.Text, nullable=True) # JSON created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) - disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 66a47ea4d..ab700c553 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Optional +import sqlalchemy as sa from celery import states # type: ignore from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column @@ -16,7 +17,7 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" - id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = mapped_column(String(155), unique=True) status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) @@ -26,12 +27,12 @@ class CeleryTask(Base): onupdate=lambda: naive_utc_now(), nullable=True, ) - traceback = mapped_column(db.Text, nullable=True) + traceback = mapped_column(sa.Text, nullable=True) name = mapped_column(String(155), nullable=True) - args = mapped_column(db.LargeBinary, nullable=True) - kwargs = mapped_column(db.LargeBinary, nullable=True) + args = mapped_column(sa.LargeBinary, nullable=True) + kwargs = mapped_column(sa.LargeBinary, nullable=True) worker = mapped_column(String(155), nullable=True) - retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) + retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) queue = mapped_column(String(155), nullable=True) @@ -41,7 +42,7 @@ class CeleryTaskSet(Base): __tablename__ = "celery_tasksetmeta" id: Mapped[int] = mapped_column( - db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True + sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 1491cd90c..408c1371c 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -25,33 +25,33 @@ from .types import StringUUID class ToolOAuthSystemClient(Base): __tablename__ = "tool_oauth_system_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) plugin_id = mapped_column(String(512), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) # tenant level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthTenantClient(Base): __tablename__ = "tool_oauth_tenant_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), - db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) # oauth params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) @property def oauth_params(self) -> dict: @@ -65,14 +65,14 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) @@ -81,19 +81,19 @@ class BuiltinToolProvider(Base): # name of the tool provider provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=sa.text("'api-key'::character varying") ) - expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) + expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) @property def credentials(self) -> dict: @@ -107,28 +107,28 @@ class ApiToolProvider(Base): __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying")) # icon icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema - schema = mapped_column(db.Text, nullable=False) + schema = mapped_column(sa.Text, nullable=False) schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # description of the provider - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # json format tools - tools_str = mapped_column(db.Text, nullable=False) + tools_str = mapped_column(sa.Text, nullable=False) # json format credentials - credentials_str = mapped_column(db.Text, nullable=False) + credentials_str = mapped_column(sa.Text, nullable=False) # privacy policy privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer @@ -167,11 +167,11 @@ class ToolLabelBinding(Base): __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), - db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), + sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type @@ -187,12 +187,12 @@ class WorkflowToolProvider(Base): __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), + sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the workflow provider name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider @@ -208,17 +208,17 @@ class WorkflowToolProvider(Base): # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # description of the provider - description: Mapped[str] = mapped_column(db.Text, nullable=False) + description: Mapped[str] = mapped_column(sa.Text, nullable=False) # parameter configuration - parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") + parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]") # privacy policy privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) @property @@ -245,19 +245,19 @@ class MCPToolProvider(Base): __tablename__ = "tool_mcp_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), - db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), - db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), - db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), + sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), + sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), + sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), + sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the mcp provider name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider - server_url: Mapped[str] = mapped_column(db.Text, nullable=False) + server_url: Mapped[str] = mapped_column(sa.Text, nullable=False) # hash of server_url for uniqueness check server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider @@ -267,16 +267,16 @@ class MCPToolProvider(Base): # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # encrypted credentials - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) # authed - authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False) + authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) # tools - tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") + tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) def load_user(self) -> Account | None: @@ -347,9 +347,9 @@ class ToolModelInvoke(Base): """ __tablename__ = "tool_model_invokes" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # who invoke this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -361,18 +361,18 @@ class ToolModelInvoke(Base): # tool name tool_name = mapped_column(String(128), nullable=False) # invoke parameters - model_parameters = mapped_column(db.Text, nullable=False) + model_parameters = mapped_column(sa.Text, nullable=False) # prompt messages - prompt_messages = mapped_column(db.Text, nullable=False) + prompt_messages = mapped_column(sa.Text, nullable=False) # invoke response - model_response = mapped_column(db.Text, nullable=False) + model_response = mapped_column(sa.Text, nullable=False) - prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) + prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -386,13 +386,13 @@ class ToolConversationVariables(Base): __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), + sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index("user_id_idx", "user_id"), - db.Index("conversation_id_idx", "conversation_id"), + sa.Index("user_id_idx", "user_id"), + sa.Index("conversation_id_idx", "conversation_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -400,7 +400,7 @@ class ToolConversationVariables(Base): # conversation id conversation_id = mapped_column(StringUUID, nullable=False) # variables pool - variables_str = mapped_column(db.Text, nullable=False) + variables_str = mapped_column(sa.Text, nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -417,11 +417,11 @@ class ToolFile(Base): __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_file_pkey"), - db.Index("tool_file_conversation_id_idx", "conversation_id"), + sa.PrimaryKeyConstraint("id", name="tool_file_pkey"), + sa.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id @@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base): __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), - db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), + sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # id of the app app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who published this tool - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # llm_description of the tool, for LLM - llm_description = mapped_column(db.Text, nullable=False) + llm_description = mapped_column(sa.Text, nullable=False) # query description, query will be seem as a parameter of the tool, # to describe this parameter to llm, we need this field - query_description = mapped_column(db.Text, nullable=False) + query_description = mapped_column(sa.Text, nullable=False) # query name, the name of the query parameter query_name = mapped_column(String(40), nullable=False) # name of the tool provider tool_name = mapped_column(String(40), nullable=False) # author author = mapped_column(String(40), nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index 1bf9b5c76..74f99e187 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,5 +1,6 @@ from datetime import datetime +import sqlalchemy as sa from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column @@ -13,15 +14,15 @@ from .types import StringUUID class SavedMessage(Base): __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="saved_message_pkey"), - db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="saved_message_pkey"), + sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) created_by_role = mapped_column( - String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -34,15 +35,15 @@ class SavedMessage(Base): class PinnedConversation(Base): __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), - db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = mapped_column( - String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 6c7d061bb..9cf6a0045 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -6,6 +6,7 @@ from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 +import sqlalchemy as sa from flask_login import current_user from sqlalchemy import DateTime, orm @@ -24,7 +25,6 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode -import sqlalchemy as sa from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column @@ -117,11 +117,11 @@ class Workflow(Base): __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_pkey"), - db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), + sa.PrimaryKeyConstraint("id", name="workflow_pkey"), + sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) @@ -140,10 +140,10 @@ class Workflow(Base): server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( - "environment_variables", db.Text, nullable=False, server_default="{}" + "environment_variables", sa.Text, nullable=False, server_default="{}" ) _conversation_variables: Mapped[str] = mapped_column( - "conversation_variables", db.Text, nullable=False, server_default="{}" + "conversation_variables", sa.Text, nullable=False, server_default="{}" ) VERSION_DRAFT = "draft" @@ -491,11 +491,11 @@ class WorkflowRun(Base): __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), - db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) @@ -503,19 +503,19 @@ class WorkflowRun(Base): type: Mapped[str] = mapped_column(String(255)) triggered_from: Mapped[str] = mapped_column(String(255)) version: Mapped[str] = mapped_column(String(255)) - graph: Mapped[Optional[str]] = mapped_column(db.Text) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) + graph: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) - total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) + total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) - exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) + exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property def created_by_account(self): @@ -704,25 +704,25 @@ class WorkflowNodeExecutionModel(Base): ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) triggered_from: Mapped[str] = mapped_column(String(255)) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) - index: Mapped[int] = mapped_column(db.Integer) + index: Mapped[int] = mapped_column(sa.Integer) predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) node_id: Mapped[str] = mapped_column(String(255)) node_type: Mapped[str] = mapped_column(String(255)) title: Mapped[str] = mapped_column(String(255)) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) - process_data: Mapped[Optional[str]] = mapped_column(db.Text) - outputs: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + process_data: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[Optional[str]] = mapped_column(sa.Text) status: Mapped[str] = mapped_column(String(255)) - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) @@ -834,11 +834,11 @@ class WorkflowAppLog(Base): __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), - db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -871,7 +871,7 @@ class ConversationVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - data: Mapped[str] = mapped_column(db.Text, nullable=False) + data: Mapped[str] = mapped_column(sa.Text, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), index=True ) @@ -933,7 +933,7 @@ class WorkflowDraftVariable(Base): __allow_unmapped__ = True # id is the unique identifier of a draft variable. - id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) created_at: Mapped[datetime] = mapped_column( DateTime, diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 7a4f886bf..c5ad65ec8 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -2,6 +2,7 @@ import json import logging import click +import sqlalchemy as sa from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID from models.engine import db @@ -38,7 +39,7 @@ class PluginDataMigration: where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != '' limit 1000""" with db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) current_iter_count = 0 for i in rs: @@ -94,7 +95,7 @@ limit 1000""" :provider_name {update_retrieval_model_sql} where id = :record_id""" - conn.execute(db.text(sql), params) + conn.execute(sa.text(sql), params) click.echo( click.style( f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})", @@ -148,7 +149,7 @@ limit 1000""" params = {"last_id": last_id or ""} with db.engine.begin() as conn: - rs = conn.execute(db.text(sql), params) + rs = conn.execute(sa.text(sql), params) current_iter_count = 0 batch_updates = [] @@ -193,7 +194,7 @@ limit 1000""" SET {provider_column_name} = :updated_value WHERE id = :record_id """ - conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) + conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) click.echo( click.style( f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]", diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 222d70a31..221069b2b 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -9,6 +9,7 @@ from typing import Any, Optional from uuid import uuid4 import click +import sqlalchemy as sa import tqdm from flask import Flask, current_app from sqlalchemy.orm import Session @@ -197,7 +198,7 @@ class PluginMigration: """ with Session(db.engine) as session: rs = session.execute( - db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} + sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} ) result = [] for row in rs: diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b6f772dd6..929b60e52 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,6 +3,7 @@ import time from collections.abc import Callable import click +import sqlalchemy as sa from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError @@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str): def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: - rs = conn.execute(db.text(query_sql), params) + rs = conn.execute(sa.text(query_sql), params) if rs.rowcount == 0: break