From 6d3e198c3cf38328fa4672f12293b241a8dfc02c Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 23 Jul 2025 01:39:59 +0900 Subject: [PATCH] Mapped column (#22644) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/core/indexing_runner.py | 3 +- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 4 +- api/core/rag/extractor/notion_extractor.py | 5 +- api/models/account.py | 105 +-- api/models/api_based_extension.py | 13 +- api/models/dataset.py | 403 +++++------ api/models/model.py | 642 +++++++++--------- api/models/source.py | 33 +- api/models/task.py | 38 +- api/models/tools.py | 94 +-- api/models/web.py | 26 +- api/schedule/clean_unused_datasets_task.py | 8 +- api/services/account_service.py | 2 +- api/services/app_service.py | 23 +- api/services/billing_service.py | 2 +- api/services/dataset_service.py | 14 +- api/tasks/create_segment_to_index_task.py | 22 +- .../unit_tests/models/test_types_enum_text.py | 22 +- .../position_helper/test_position_helper.py | 2 +- 19 files changed, 745 insertions(+), 716 deletions(-) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 305a9190d..e5976f4c9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -672,8 +672,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - - db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) + db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore db.session.commit() @staticmethod diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 46aefef11..b0f0eeca3 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -6,7 +6,7 @@ from uuid import UUID, uuid4 from numpy import ndarray from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator -from sqlalchemy import Float, String, create_engine, insert, select, text +from sqlalchemy import Float, create_engine, insert, select, text from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): postgresql.UUID(as_uuid=True), primary_key=True, ) - text: Mapped[str] = mapped_column(String) + text: Mapped[str] meta: Mapped[dict] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index eca955ddd..81a0810e2 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -331,9 +331,10 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict data_source_info["last_edited_time"] = last_edited_time - update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} - db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params) + db.session.query(DocumentModel).filter_by(id=document_model.id).update( + {DocumentModel.data_source_info: json.dumps(data_source_info)} + ) # type: ignore db.session.commit() def get_notion_last_edited_time(self) -> str: diff --git a/api/models/account.py b/api/models/account.py index 1af571bc0..01bf9d5f7 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,5 +1,6 @@ import enum import json +from datetime import datetime from typing import Optional, cast from flask_login import UserMixin # type: ignore @@ -85,21 +86,23 @@ class Account(UserMixin, Base): __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) - email = db.Column(db.String(255), nullable=False) - password = db.Column(db.String(255), nullable=True) - password_salt = db.Column(db.String(255), nullable=True) - avatar = db.Column(db.String(255)) - interface_language = db.Column(db.String(255)) - interface_theme = db.Column(db.String(255)) - timezone = db.Column(db.String(255)) - last_login_at = db.Column(db.DateTime) - last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) - initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(db.String(255)) + email: Mapped[str] = mapped_column(db.String(255)) + password: Mapped[Optional[str]] = mapped_column(db.String(255)) + password_salt: Mapped[Optional[str]] = mapped_column(db.String(255)) + avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column( + db.DateTime, server_default=func.current_timestamp(), nullable=False + ) + status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -143,7 +146,7 @@ class Account(UserMixin, Base): return tenant, join = tenant_account_join - self.role = join.role + self.role = TenantAccountRole(join.role) self._current_tenant = tenant @property @@ -196,14 +199,14 @@ class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(db.String(255)) encrypt_public_key = db.Column(db.Text) - plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + custom_config: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -230,14 +233,14 @@ class TenantAccountJoin(Base): db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - role = db.Column(db.String(16), nullable=False, server_default="normal") - invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + account_id: Mapped[str] = mapped_column(StringUUID) + current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): @@ -248,13 +251,13 @@ class AccountIntegrate(Base): db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - account_id = db.Column(StringUUID, nullable=False) - provider = db.Column(db.String(16), nullable=False) - open_id = db.Column(db.String(255), nullable=False) - encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + account_id: Mapped[str] = mapped_column(StringUUID) + provider: Mapped[str] = mapped_column(db.String(16)) + open_id: Mapped[str] = mapped_column(db.String(255)) + encrypted_token: Mapped[str] = mapped_column(db.String(255)) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): @@ -265,15 +268,15 @@ class InvitationCode(Base): db.Index("invitation_codes_code_idx", "code", "status"), ) - id = db.Column(db.Integer, nullable=False) - batch = db.Column(db.String(255), nullable=False) - code = db.Column(db.String(32), nullable=False) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) - used_at = db.Column(db.DateTime) - used_by_tenant_id = db.Column(StringUUID) - used_by_account_id = db.Column(StringUUID) - deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id: Mapped[int] = mapped_column(db.Integer) + batch: Mapped[str] = mapped_column(db.String(255)) + code: Mapped[str] = mapped_column(db.String(32)) + status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) + used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -294,8 +297,6 @@ class TenantPluginPermission(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column( - db.String(16), nullable=False, server_default="everyone" - ) - debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + tenant_id: Mapped[str] = mapped_column(StringUUID) + install_permission: Mapped[InstallPermission] = mapped_column(db.String(16), server_default="everyone") + debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), server_default="noone") diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 5a70e1862..3cef5a0fb 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,6 +1,7 @@ import enum from sqlalchemy import func +from sqlalchemy.orm import mapped_column from .base import Base from .engine import db @@ -21,9 +22,9 @@ class APIBasedExtension(Base): db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - api_endpoint = db.Column(db.String(255), nullable=False) - api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + api_endpoint = mapped_column(db.String(255), nullable=False) + api_key = mapped_column(db.Text, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 57e54b72a..d5a13efb9 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -8,12 +8,13 @@ import os import pickle import re import time +from datetime import datetime from json import JSONDecodeError -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource @@ -45,24 +46,24 @@ class Dataset(Base): INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) - data_source_type = db.Column(db.String(255)) - indexing_technique = db.Column(db.String(255), nullable=True) - index_struct = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) - embedding_model_provider = db.Column(db.String(255), nullable=True) - collection_binding_id = db.Column(StringUUID, nullable=True) - retrieval_model = db.Column(JSONB, nullable=True) - built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + name: Mapped[str] = mapped_column(db.String(255)) + description = mapped_column(db.Text, nullable=True) + provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) + data_source_type = mapped_column(db.String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) + index_struct = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + collection_binding_id = mapped_column(StringUUID, nullable=True) + retrieval_model = mapped_column(JSONB, nullable=True) + built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def dataset_keyword_table(self): @@ -265,12 +266,12 @@ class DatasetProcessRule(Base): db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - rules = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False) + mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + rules = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -309,62 +310,64 @@ class Document(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False) - data_source_info = db.Column(db.Text, nullable=True) - dataset_process_rule_id = db.Column(StringUUID, nullable=True) - batch = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_from = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + data_source_type = mapped_column(db.String(255), nullable=False) + data_source_info = mapped_column(db.Text, nullable=True) + dataset_process_rule_id = mapped_column(StringUUID, nullable=True) + batch = mapped_column(db.String(255), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_from = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_api_request_id = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = db.Column(db.DateTime, nullable=True) + processing_started_at = mapped_column(db.DateTime, nullable=True) # parsing - file_id = db.Column(db.Text, nullable=True) - word_count = db.Column(db.Integer, nullable=True) - parsing_completed_at = db.Column(db.DateTime, nullable=True) + file_id = mapped_column(db.Text, nullable=True) + word_count = mapped_column(db.Integer, nullable=True) + parsing_completed_at = mapped_column(db.DateTime, nullable=True) # cleaning - cleaning_completed_at = db.Column(db.DateTime, nullable=True) + cleaning_completed_at = mapped_column(db.DateTime, nullable=True) # split - splitting_completed_at = db.Column(db.DateTime, nullable=True) + splitting_completed_at = mapped_column(db.DateTime, nullable=True) # indexing - tokens = db.Column(db.Integer, nullable=True) - indexing_latency = db.Column(db.Float, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) + tokens = mapped_column(db.Integer, nullable=True) + indexing_latency = mapped_column(db.Float, nullable=True) + completed_at = mapped_column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) - paused_by = db.Column(StringUUID, nullable=True) - paused_at = db.Column(db.DateTime, nullable=True) + is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + paused_by = mapped_column(StringUUID, nullable=True) + paused_at = mapped_column(db.DateTime, nullable=True) # error - error = db.Column(db.Text, nullable=True) - stopped_at = db.Column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) + stopped_at = mapped_column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = db.Column(db.String(255), nullable=True) - archived_by = db.Column(StringUUID, nullable=True) - archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = db.Column(db.String(40), nullable=True) - doc_metadata = db.Column(JSONB, nullable=True) - doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = db.Column(db.String(255), nullable=True) + indexing_status = mapped_column( + db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") + ) + enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = mapped_column(db.DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + archived_reason = mapped_column(db.String(255), nullable=True) + archived_by = mapped_column(StringUUID, nullable=True) + archived_at = mapped_column(db.DateTime, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(db.String(40), nullable=True) + doc_metadata = mapped_column(JSONB, nullable=True) + doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = mapped_column(db.String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -652,35 +655,35 @@ class DocumentSegment(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] - content = db.Column(db.Text, nullable=False) - answer = db.Column(db.Text, nullable=True) - word_count = db.Column(db.Integer, nullable=False) - tokens = db.Column(db.Integer, nullable=False) + content = mapped_column(db.Text, nullable=False) + answer = mapped_column(db.Text, nullable=True) + word_count: Mapped[int] + tokens: Mapped[int] # indexing fields - keywords = db.Column(db.JSON, nullable=True) - index_node_id = db.Column(db.String(255), nullable=True) - index_node_hash = db.Column(db.String(255), nullable=True) + keywords = mapped_column(db.JSON, nullable=True) + index_node_id = mapped_column(db.String(255), nullable=True) + index_node_hash = mapped_column(db.String(255), nullable=True) # basic fields - hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = db.Column(db.DateTime, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) - error = db.Column(db.Text, nullable=True) - stopped_at = db.Column(db.DateTime, nullable=True) + hit_count = mapped_column(db.Integer, nullable=False, default=0) + enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + disabled_at = mapped_column(db.DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at = mapped_column(db.DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) + stopped_at = mapped_column(db.DateTime, nullable=True) @property def dataset(self): @@ -800,25 +803,25 @@ class ChildChunk(Base): ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - segment_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False) - content = db.Column(db.Text, nullable=False) - word_count = db.Column(db.Integer, nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + segment_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + content = mapped_column(db.Text, nullable=False) + word_count = mapped_column(db.Integer, nullable=False) # indexing fields - index_node_id = db.Column(db.String(255), nullable=True) - index_node_hash = db.Column(db.String(255), nullable=True) - type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = db.Column(db.DateTime, nullable=True) - completed_at = db.Column(db.DateTime, nullable=True) - error = db.Column(db.Text, nullable=True) + index_node_id = mapped_column(db.String(255), nullable=True) + index_node_hash = mapped_column(db.String(255), nullable=True) + type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = mapped_column(db.DateTime, nullable=True) + completed_at = mapped_column(db.DateTime, nullable=True) + error = mapped_column(db.Text, nullable=True) @property def dataset(self): @@ -840,10 +843,10 @@ class AppDatasetJoin(Base): db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -857,14 +860,14 @@ class DatasetQuery(Base): db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False) - content = db.Column(db.Text, nullable=False) - source = db.Column(db.String(255), nullable=False) - source_app_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False) + content = mapped_column(db.Text, nullable=False) + source = mapped_column(db.String(255), nullable=False) + source_app_id = mapped_column(StringUUID, nullable=True) + created_by_role = mapped_column(db.String, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): @@ -874,10 +877,10 @@ class DatasetKeywordTable(Base): db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - dataset_id = db.Column(StringUUID, nullable=False, unique=True) - keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column( + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + dataset_id = mapped_column(StringUUID, nullable=False, unique=True) + keyword_table = mapped_column(db.Text, nullable=False) + data_source_type = mapped_column( db.String(255), nullable=False, server_default=db.text("'database'::character varying") ) @@ -920,14 +923,14 @@ class Embedding(Base): db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - model_name = db.Column( + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = mapped_column( db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") ) - hash = db.Column(db.String(64), nullable=False) - embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + hash = mapped_column(db.String(64), nullable=False) + embedding = mapped_column(db.LargeBinary, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -943,12 +946,12 @@ class DatasetCollectionBinding(Base): db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + provider_name = mapped_column(db.String(255), nullable=False) + model_name = mapped_column(db.String(255), nullable=False) + type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(db.String(64), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): @@ -960,15 +963,15 @@ class TidbAuthBinding(Base): db.Index("tidb_auth_bindings_created_at_idx", "created_at"), db.Index("tidb_auth_bindings_status_idx", "status"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - cluster_id = db.Column(db.String(255), nullable=False) - cluster_name = db.Column(db.String(255), nullable=False) - active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = db.Column(db.String(255), nullable=False) - password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + cluster_id = mapped_column(db.String(255), nullable=False) + cluster_name = mapped_column(db.String(255), nullable=False) + active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) + account = mapped_column(db.String(255), nullable=False) + password = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): @@ -977,10 +980,10 @@ class Whitelist(Base): db.PrimaryKeyConstraint("id", name="whitelists_pkey"), db.Index("whitelists_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + category = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): @@ -992,12 +995,12 @@ class DatasetPermission(Base): db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) - dataset_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + dataset_id = mapped_column(StringUUID, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): @@ -1008,15 +1011,15 @@ class ExternalKnowledgeApis(Base): db.Index("external_knowledge_apis_name_idx", "name"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.String(255), nullable=False) - tenant_id = db.Column(StringUUID, nullable=False) - settings = db.Column(db.Text, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(db.String(255), nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + settings = mapped_column(db.Text, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1063,15 +1066,15 @@ class ExternalKnowledgeBindings(Base): db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - external_knowledge_api_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - external_knowledge_id = db.Column(db.Text, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + external_knowledge_api_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + external_knowledge_id = mapped_column(db.Text, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): @@ -1083,12 +1086,12 @@ class DatasetAutoDisableLog(Base): db.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class RateLimitLog(Base): @@ -1099,11 +1102,11 @@ class RateLimitLog(Base): db.Index("rate_limit_log_operation_idx", "operation"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - subscription_plan = db.Column(db.String(255), nullable=False) - operation = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + subscription_plan = mapped_column(db.String(255), nullable=False) + operation = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetMetadata(Base): @@ -1114,15 +1117,15 @@ class DatasetMetadata(Base): db.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - created_by = db.Column(StringUUID, nullable=False) - updated_by = db.Column(StringUUID, nullable=True) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + type = mapped_column(db.String(255), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_by = mapped_column(StringUUID, nullable=False) + updated_by = mapped_column(StringUUID, nullable=True) class DatasetMetadataBinding(Base): @@ -1135,10 +1138,10 @@ class DatasetMetadataBinding(Base): db.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - metadata_id = db.Column(StringUUID, nullable=False) - document_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + metadata_id = mapped_column(StringUUID, nullable=False) + document_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 2377aeed8..b8e8b7801 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -40,8 +40,8 @@ class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + version = mapped_column(db.String(255), nullable=False) + setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -74,31 +74,31 @@ class App(Base): __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) # image, emoji + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + name: Mapped[str] = mapped_column(db.String(255)) + description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) + mode: Mapped[str] = mapped_column(db.String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji icon = db.Column(db.String(255)) - icon_background = db.Column(db.String(255)) - app_model_config_id = db.Column(StringUUID, nullable=True) - workflow_id = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - enable_site = db.Column(db.Boolean, nullable=False) - enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - tracing = db.Column(db.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + app_model_config_id = mapped_column(StringUUID, nullable=True) + workflow_id = mapped_column(StringUUID, nullable=True) + status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) + enable_site: Mapped[bool] = mapped_column(db.Boolean) + enable_api: Mapped[bool] = mapped_column(db.Boolean) + api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) + api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) + is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) + tracing = mapped_column(db.Text, nullable=True) + max_active_requests: Mapped[Optional[int]] + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def desc_or_prompt(self): @@ -307,34 +307,34 @@ class AppModelConfig(Base): __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - provider = db.Column(db.String(255), nullable=True) - model_id = db.Column(db.String(255), nullable=True) - configs = db.Column(db.JSON, nullable=True) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - opening_statement = db.Column(db.Text) - suggested_questions = db.Column(db.Text) - suggested_questions_after_answer = db.Column(db.Text) - speech_to_text = db.Column(db.Text) - text_to_speech = db.Column(db.Text) - more_like_this = db.Column(db.Text) - model = db.Column(db.Text) - user_input_form = db.Column(db.Text) - dataset_query_variable = db.Column(db.String(255)) - pre_prompt = db.Column(db.Text) - agent_mode = db.Column(db.Text) - sensitive_word_avoidance = db.Column(db.Text) - retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) - chat_prompt_config = db.Column(db.Text) - completion_prompt_config = db.Column(db.Text) - dataset_configs = db.Column(db.Text) - external_data_tools = db.Column(db.Text) - file_upload = db.Column(db.Text) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + provider = mapped_column(db.String(255), nullable=True) + model_id = mapped_column(db.String(255), nullable=True) + configs = mapped_column(db.JSON, nullable=True) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = mapped_column(db.Text) + suggested_questions = mapped_column(db.Text) + suggested_questions_after_answer = mapped_column(db.Text) + speech_to_text = mapped_column(db.Text) + text_to_speech = mapped_column(db.Text) + more_like_this = mapped_column(db.Text) + model = mapped_column(db.Text) + user_input_form = mapped_column(db.Text) + dataset_query_variable = mapped_column(db.String(255)) + pre_prompt = mapped_column(db.Text) + agent_mode = mapped_column(db.Text) + sensitive_word_avoidance = mapped_column(db.Text) + retriever_resource = mapped_column(db.Text) + prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + chat_prompt_config = mapped_column(db.Text) + completion_prompt_config = mapped_column(db.Text) + dataset_configs = mapped_column(db.Text) + external_data_tools = mapped_column(db.Text) + file_upload = mapped_column(db.Text) @property def app(self): @@ -561,19 +561,19 @@ class RecommendedApp(Base): db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - description = db.Column(db.JSON, nullable=False) - copyright = db.Column(db.String(255), nullable=False) - privacy_policy = db.Column(db.String(255), nullable=False) + id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + description = mapped_column(db.JSON, nullable=False) + copyright = mapped_column(db.String(255), nullable=False) + privacy_policy = mapped_column(db.String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = db.Column(db.String(255), nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - is_listed = db.Column(db.Boolean, nullable=False, default=True) - install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category = mapped_column(db.String(255), nullable=False) + position = mapped_column(db.Integer, nullable=False, default=0) + is_listed = mapped_column(db.Boolean, nullable=False, default=True) + install_count = mapped_column(db.Integer, nullable=False, default=0) + language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -590,14 +590,14 @@ class InstalledApp(Base): db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - app_owner_tenant_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + app_owner_tenant_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False, default=0) + is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + last_used_at = mapped_column(db.DateTime, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -618,42 +618,42 @@ class Conversation(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - app_model_config_id = db.Column(StringUUID, nullable=True) - model_provider = db.Column(db.String(255), nullable=True) - override_model_configs = db.Column(db.Text) - model_id = db.Column(db.String(255), nullable=True) + app_id = mapped_column(StringUUID, nullable=False) + app_model_config_id = mapped_column(StringUUID, nullable=True) + model_provider = mapped_column(db.String(255), nullable=True) + override_model_configs = mapped_column(db.Text) + model_id = mapped_column(db.String(255), nullable=True) mode: Mapped[str] = mapped_column(db.String(255)) - name = db.Column(db.String(255), nullable=False) - summary = db.Column(db.Text) + name = mapped_column(db.String(255), nullable=False) + summary = mapped_column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - introduction = db.Column(db.Text) - system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - status = db.Column(db.String(255), nullable=False) + introduction = mapped_column(db.Text) + system_instruction = mapped_column(db.Text) + system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + status = mapped_column(db.String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = db.Column(db.String(255), nullable=True) + invoke_from = mapped_column(db.String(255), nullable=True) # ref: ConversationSource. - from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) - read_at = db.Column(db.DateTime) - read_account_id = db.Column(StringUUID) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id = mapped_column(StringUUID) + from_account_id = mapped_column(StringUUID) + read_at = mapped_column(db.DateTime) + read_account_id = mapped_column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) @property def inputs(self): @@ -896,36 +896,36 @@ class Message(Base): ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - model_provider = db.Column(db.String(255), nullable=True) - model_id = db.Column(db.String(255), nullable=True) - override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + model_provider = mapped_column(db.String(255), nullable=True) + model_id = mapped_column(db.String(255), nullable=True) + override_model_configs = mapped_column(db.Text) + conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - query: Mapped[str] = db.Column(db.Text, nullable=False) - message = db.Column(db.JSON, nullable=False) - message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer: Mapped[str] = db.Column(db.Text, nullable=False) - answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - parent_message_id = db.Column(StringUUID, nullable=True) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_price = db.Column(db.Numeric(10, 7)) - currency = db.Column(db.String(255), nullable=False) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - error = db.Column(db.Text) - message_metadata = db.Column(db.Text) - invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) - from_source = db.Column(db.String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) - from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) + query: Mapped[str] = mapped_column(db.Text, nullable=False) + message = mapped_column(db.JSON, nullable=False) + message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column + answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + parent_message_id = mapped_column(StringUUID, nullable=True) + provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) + total_price = mapped_column(db.Numeric(10, 7)) + currency = mapped_column(db.String(255), nullable=False) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + error = mapped_column(db.Text) + message_metadata = mapped_column(db.Text) + invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - workflow_run_id: Mapped[str] = db.Column(StringUUID) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property def inputs(self): @@ -1239,17 +1239,17 @@ class MessageFeedback(Base): db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) - message_id = db.Column(StringUUID, nullable=False) - rating = db.Column(db.String(255), nullable=False) - content = db.Column(db.Text) - from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + conversation_id = mapped_column(StringUUID, nullable=False) + message_id = mapped_column(StringUUID, nullable=False) + rating = mapped_column(db.String(255), nullable=False) + content = mapped_column(db.Text) + from_source = mapped_column(db.String(255), nullable=False) + from_end_user_id = mapped_column(StringUUID) + from_account_id = mapped_column(StringUUID) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1301,16 +1301,16 @@ class MessageFile(Base): self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - message_id: Mapped[str] = db.Column(StringUUID, nullable=False) - type: Mapped[str] = db.Column(db.String(255), nullable=False) - transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) - url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) - upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) - created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(Base): @@ -1322,16 +1322,16 @@ class MessageAnnotation(Base): db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) - message_id = db.Column(StringUUID, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id: Mapped[str] = mapped_column(StringUUID) + conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id")) + message_id: Mapped[Optional[str]] = mapped_column(StringUUID) question = db.Column(db.Text, nullable=True) - content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + content = mapped_column(db.Text, nullable=False) + hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + account_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1354,17 +1354,17 @@ class AppAnnotationHitHistory(Base): db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False) - source = db.Column(db.Text, nullable=False) - question = db.Column(db.Text, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - score = db.Column(Float, nullable=False, server_default=db.text("0")) - message_id = db.Column(StringUUID, nullable=False) - annotation_question = db.Column(db.Text, nullable=False) - annotation_content = db.Column(db.Text, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + source = mapped_column(db.Text, nullable=False) + question = mapped_column(db.Text, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + score = mapped_column(Float, nullable=False, server_default=db.text("0")) + message_id = mapped_column(StringUUID, nullable=False) + annotation_question = mapped_column(db.Text, nullable=False) + annotation_content = mapped_column(db.Text, nullable=False) @property def account(self): @@ -1389,14 +1389,14 @@ class AppAnnotationSetting(Base): db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) - collection_binding_id = db.Column(StringUUID, nullable=False) - created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0")) + collection_binding_id = mapped_column(StringUUID, nullable=False) + created_user_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_user_id = mapped_column(StringUUID, nullable=False) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def collection_binding_detail(self): @@ -1417,14 +1417,14 @@ class OperationLog(Base): db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - account_id = db.Column(StringUUID, nullable=False) - action = db.Column(db.String(255), nullable=False) - content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + account_id = mapped_column(StringUUID, nullable=False) + action = mapped_column(db.String(255), nullable=False) + content = mapped_column(db.JSON) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_ip = mapped_column(db.String(255), nullable=False) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): @@ -1435,16 +1435,16 @@ class EndUser(Base, UserMixin): db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(255), nullable=False) - external_user_id = db.Column(db.String(255), nullable=True) - name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(255), nullable=False) + external_user_id = mapped_column(db.String(255), nullable=True) + name = mapped_column(db.String(255)) + is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) session_id: Mapped[str] = mapped_column() - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMCPServer(Base): @@ -1454,17 +1454,17 @@ class AppMCPServer(Base): db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.String(255), nullable=False) - server_code = db.Column(db.String(255), nullable=False) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - parameters = db.Column(db.Text, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + name = mapped_column(db.String(255), nullable=False) + description = mapped_column(db.String(255), nullable=False) + server_code = mapped_column(db.String(255), nullable=False) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + parameters = mapped_column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_server_code(n): @@ -1488,30 +1488,30 @@ class Site(Base): db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - title = db.Column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) - icon = db.Column(db.String(255)) - icon_background = db.Column(db.String(255)) - description = db.Column(db.Text) - default_language = db.Column(db.String(255), nullable=False) - chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = db.Column(db.String(255)) - privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + title = mapped_column(db.String(255), nullable=False) + icon_type = mapped_column(db.String(255), nullable=True) + icon = mapped_column(db.String(255)) + icon_background = mapped_column(db.String(255)) + description = mapped_column(db.Text) + default_language = mapped_column(db.String(255), nullable=False) + chat_color_theme = mapped_column(db.String(255)) + chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + copyright = mapped_column(db.String(255)) + privacy_policy = mapped_column(db.String(255)) + show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = db.Column(db.String(255)) - customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = db.Column(db.String(255)) + customize_domain = mapped_column(db.String(255)) + customize_token_strategy = mapped_column(db.String(255), nullable=False) + prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + created_by = mapped_column(StringUUID, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by = mapped_column(StringUUID, nullable=True) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + code = mapped_column(db.String(255)) @property def custom_disclaimer(self): @@ -1546,13 +1546,13 @@ class ApiToken(Base): db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=True) - tenant_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(16), nullable=False) - token = db.Column(db.String(255), nullable=False) - last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=True) + tenant_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(16), nullable=False) + token = mapped_column(db.String(255), nullable=False) + last_used_at = mapped_column(db.DateTime, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1570,23 +1570,23 @@ class UploadFile(Base): db.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) - storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) - key: Mapped[str] = db.Column(db.String(255), nullable=False) - name: Mapped[str] = db.Column(db.String(255), nullable=False) - size: Mapped[int] = db.Column(db.Integer, nullable=False) - extension: Mapped[str] = db.Column(db.String(255), nullable=False) - mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) - created_by_role: Mapped[str] = db.Column( + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) + key: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + size: Mapped[int] = mapped_column(db.Integer, nullable=False) + extension: Mapped[str] = mapped_column(db.String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + created_by_role: Mapped[str] = mapped_column( db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) - created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) - hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) + hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1632,14 +1632,14 @@ class ApiRequest(Base): db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - api_token_id = db.Column(StringUUID, nullable=False) - path = db.Column(db.String(255), nullable=False) - request = db.Column(db.Text, nullable=True) - response = db.Column(db.Text, nullable=True) - ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + api_token_id = mapped_column(StringUUID, nullable=False) + path = mapped_column(db.String(255), nullable=False) + request = mapped_column(db.Text, nullable=True) + response = mapped_column(db.Text, nullable=True) + ip = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): @@ -1649,12 +1649,12 @@ class MessageChain(Base): db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - input = db.Column(db.Text, nullable=True) - output = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + type = mapped_column(db.String(255), nullable=False) + input = mapped_column(db.Text, nullable=True) + output = mapped_column(db.Text, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class MessageAgentThought(Base): @@ -1665,34 +1665,34 @@ class MessageAgentThought(Base): db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - message_chain_id = db.Column(StringUUID, nullable=True) - position = db.Column(db.Integer, nullable=False) - thought = db.Column(db.Text, nullable=True) - tool = db.Column(db.Text, nullable=True) - tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_input = db.Column(db.Text, nullable=True) - observation = db.Column(db.Text, nullable=True) - # plugin_id = db.Column(StringUUID, nullable=True) ## for future design - tool_process_data = db.Column(db.Text, nullable=True) - message = db.Column(db.Text, nullable=True) - message_token = db.Column(db.Integer, nullable=True) - message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - message_files = db.Column(db.Text, nullable=True) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + message_chain_id = mapped_column(StringUUID, nullable=True) + position = mapped_column(db.Integer, nullable=False) + thought = mapped_column(db.Text, nullable=True) + tool = mapped_column(db.Text, nullable=True) + tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_input = mapped_column(db.Text, nullable=True) + observation = mapped_column(db.Text, nullable=True) + # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design + tool_process_data = mapped_column(db.Text, nullable=True) + message = mapped_column(db.Text, nullable=True) + message_token = mapped_column(db.Integer, nullable=True) + message_unit_price = mapped_column(db.Numeric, nullable=True) + message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + message_files = mapped_column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) - answer_token = db.Column(db.Integer, nullable=True) - answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = db.Column(db.Integer, nullable=True) - total_price = db.Column(db.Numeric, nullable=True) - currency = db.Column(db.String, nullable=True) - latency = db.Column(db.Float, nullable=True) - created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + answer_token = mapped_column(db.Integer, nullable=True) + answer_unit_price = mapped_column(db.Numeric, nullable=True) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + tokens = mapped_column(db.Integer, nullable=True) + total_price = mapped_column(db.Numeric, nullable=True) + currency = mapped_column(db.String, nullable=True) + latency = mapped_column(db.Float, nullable=True) + created_by_role = mapped_column(db.String, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def files(self) -> list: @@ -1778,24 +1778,24 @@ class DatasetRetrieverResource(Base): db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - message_id = db.Column(StringUUID, nullable=False) - position = db.Column(db.Integer, nullable=False) - dataset_id = db.Column(StringUUID, nullable=False) - dataset_name = db.Column(db.Text, nullable=False) - document_id = db.Column(StringUUID, nullable=True) - document_name = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.Text, nullable=True) - segment_id = db.Column(StringUUID, nullable=True) - score = db.Column(db.Float, nullable=True) - content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=True) - word_count = db.Column(db.Integer, nullable=True) - segment_position = db.Column(db.Integer, nullable=True) - index_node_hash = db.Column(db.Text, nullable=True) - retriever_from = db.Column(db.Text, nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + message_id = mapped_column(StringUUID, nullable=False) + position = mapped_column(db.Integer, nullable=False) + dataset_id = mapped_column(StringUUID, nullable=False) + dataset_name = mapped_column(db.Text, nullable=False) + document_id = mapped_column(StringUUID, nullable=True) + document_name = mapped_column(db.Text, nullable=False) + data_source_type = mapped_column(db.Text, nullable=True) + segment_id = mapped_column(StringUUID, nullable=True) + score = mapped_column(db.Float, nullable=True) + content = mapped_column(db.Text, nullable=False) + hit_count = mapped_column(db.Integer, nullable=True) + word_count = mapped_column(db.Integer, nullable=True) + segment_position = mapped_column(db.Integer, nullable=True) + index_node_hash = mapped_column(db.Text, nullable=True) + retriever_from = mapped_column(db.Text, nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) class Tag(Base): @@ -1808,12 +1808,12 @@ class Tag(Base): TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - type = db.Column(db.String(16), nullable=False) - name = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + type = mapped_column(db.String(16), nullable=False) + name = mapped_column(db.String(255), nullable=False) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): @@ -1824,12 +1824,12 @@ class TagBinding(Base): db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=True) - tag_id = db.Column(StringUUID, nullable=True) - target_id = db.Column(StringUUID, nullable=True) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=True) + tag_id = mapped_column(StringUUID, nullable=True) + target_id = mapped_column(StringUUID, nullable=True) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): @@ -1839,15 +1839,15 @@ class TraceAppConfig(Base): db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - tracing_provider = db.Column(db.String(255), nullable=True) - tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column( + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + tracing_provider = mapped_column(db.String(255), nullable=True) + tracing_config = mapped_column(db.JSON, nullable=True) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/source.py b/api/models/source.py index f6e0900ae..100e0d96e 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -2,6 +2,7 @@ import json from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import mapped_column from models.base import Base @@ -17,14 +18,14 @@ class DataSourceOauthBinding(Base): db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - access_token = db.Column(db.String(255), nullable=False) - provider = db.Column(db.String(255), nullable=False) - source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + access_token = mapped_column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) + source_info = mapped_column(JSONB, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(Base): @@ -35,14 +36,14 @@ class DataSourceApiKeyAuthBinding(Base): db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - category = db.Column(db.String(255), nullable=False) - provider = db.Column(db.String(255), nullable=False) - credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + category = mapped_column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) + credentials = mapped_column(db.Text, nullable=True) # JSON + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 1a4b606ff..3e5ebd209 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,4 +1,8 @@ +from datetime import datetime +from typing import Optional + from celery import states # type: ignore +from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now from models.base import Base @@ -11,23 +15,23 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = db.Column(db.String(155), unique=True) - status = db.Column(db.String(50), default=states.PENDING) - result = db.Column(db.PickleType, nullable=True) - date_done = db.Column( + id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + task_id = mapped_column(db.String(155), unique=True) + status = mapped_column(db.String(50), default=states.PENDING) + result = mapped_column(db.PickleType, nullable=True) + date_done = mapped_column( db.DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) - traceback = db.Column(db.Text, nullable=True) - name = db.Column(db.String(155), nullable=True) - args = db.Column(db.LargeBinary, nullable=True) - kwargs = db.Column(db.LargeBinary, nullable=True) - worker = db.Column(db.String(155), nullable=True) - retries = db.Column(db.Integer, nullable=True) - queue = db.Column(db.String(155), nullable=True) + traceback = mapped_column(db.Text, nullable=True) + name = mapped_column(db.String(155), nullable=True) + args = mapped_column(db.LargeBinary, nullable=True) + kwargs = mapped_column(db.LargeBinary, nullable=True) + worker = mapped_column(db.String(155), nullable=True) + retries = mapped_column(db.Integer, nullable=True) + queue = mapped_column(db.String(155), nullable=True) class CeleryTaskSet(Base): @@ -35,7 +39,9 @@ class CeleryTaskSet(Base): __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) - taskset_id = db.Column(db.String(155), unique=True) - result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + id: Mapped[int] = mapped_column( + db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True + ) + taskset_id = mapped_column(db.String(155), unique=True) + result = mapped_column(db.PickleType, nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index f5fae8b79..a0b7e5417 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -110,26 +110,26 @@ class ApiToolProvider(Base): db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon - icon = db.Column(db.String(255), nullable=False) + icon = mapped_column(db.String(255), nullable=False) # original schema - schema = db.Column(db.Text, nullable=False) - schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) + schema = mapped_column(db.Text, nullable=False) + schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) # who created this tool - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # description of the provider - description = db.Column(db.Text, nullable=False) + description = mapped_column(db.Text, nullable=False) # json format tools - tools_str = db.Column(db.Text, nullable=False) + tools_str = mapped_column(db.Text, nullable=False) # json format credentials - credentials_str = db.Column(db.Text, nullable=False) + credentials_str = mapped_column(db.Text, nullable=False) # privacy policy - privacy_policy = db.Column(db.String(255), nullable=True) + privacy_policy = mapped_column(db.String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") @@ -348,33 +348,33 @@ class ToolModelInvoke(Base): __tablename__ = "tool_model_invokes" __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = db.Column(db.String(255), nullable=False) + provider = mapped_column(db.String(255), nullable=False) # type - tool_type = db.Column(db.String(40), nullable=False) + tool_type = mapped_column(db.String(40), nullable=False) # tool name - tool_name = db.Column(db.String(128), nullable=False) + tool_name = mapped_column(db.String(128), nullable=False) # invoke parameters - model_parameters = db.Column(db.Text, nullable=False) + model_parameters = mapped_column(db.Text, nullable=False) # prompt messages - prompt_messages = db.Column(db.Text, nullable=False) + prompt_messages = mapped_column(db.Text, nullable=False) # invoke response - model_response = db.Column(db.Text, nullable=False) + model_response = mapped_column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_price = db.Column(db.Numeric(10, 7)) - currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) + total_price = mapped_column(db.Numeric(10, 7)) + currency = mapped_column(db.String(255), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -391,18 +391,18 @@ class ToolConversationVariables(Base): db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id - user_id = db.Column(StringUUID, nullable=False) + user_id = mapped_column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) # conversation id - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id = mapped_column(StringUUID, nullable=False) # variables pool - variables_str = db.Column(db.Text, nullable=False) + variables_str = mapped_column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -451,26 +451,26 @@ class DeprecatedPublishedAppTool(Base): db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app - app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) + app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False) - user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who published this tool - description = db.Column(db.Text, nullable=False) + description = mapped_column(db.Text, nullable=False) # llm_description of the tool, for LLM - llm_description = db.Column(db.Text, nullable=False) + llm_description = mapped_column(db.Text, nullable=False) # query description, query will be seem as a parameter of the tool, # to describe this parameter to llm, we need this field - query_description = db.Column(db.Text, nullable=False) + query_description = mapped_column(db.Text, nullable=False) # query name, the name of the query parameter - query_name = db.Column(db.String(40), nullable=False) + query_name = mapped_column(db.String(40), nullable=False) # name of the tool provider - tool_name = db.Column(db.String(40), nullable=False) + tool_name = mapped_column(db.String(40), nullable=False) # author - author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(db.String(40), nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index fe2f0c47f..bcc95ddbc 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -15,12 +15,14 @@ class SavedMessage(Base): db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) - message_id = db.Column(StringUUID, nullable=False) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + message_id = mapped_column(StringUUID, nullable=False) + created_by_role = mapped_column( + db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + ) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -34,9 +36,11 @@ class PinnedConversation(Base): db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - app_id = db.Column(StringUUID, nullable=False) + id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_by_role = mapped_column( + db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + ) + created_by = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index c0cd42a22..be228a6d9 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -99,9 +99,7 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = {Document.enabled: False} - - db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) db.session.commit() click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: @@ -176,9 +174,7 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = {Document.enabled: False} - - db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) db.session.commit() click.echo( click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") diff --git a/api/services/account_service.py b/api/services/account_service.py index 352efb2f0..c88e70e38 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -998,7 +998,7 @@ class TenantService: .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .first() ) - return join.role if join else None + return TenantAccountRole(join.role) if join else None @staticmethod def get_tenant_count() -> int: diff --git a/api/services/app_service.py b/api/services/app_service.py index 3494b2796..cfcb414de 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, cast +from typing import Optional, TypedDict, cast from flask_login import current_user from flask_sqlalchemy.pagination import Pagination @@ -220,18 +220,27 @@ class AppService: return app - def update_app(self, app: App, args: dict) -> App: + class ArgsDict(TypedDict): + name: str + description: str + icon_type: str + icon: str + icon_background: str + use_icon_as_answer_icon: bool + max_active_requests: int + + def update_app(self, app: App, args: ArgsDict) -> App: """ Update app :param app: App instance :param args: request args :return: App instance """ - app.name = args.get("name") - app.description = args.get("description", "") - app.icon_type = args.get("icon_type", "emoji") - app.icon = args.get("icon") - app.icon_background = args.get("icon_background") + app.name = args["name"] + app.description = args["description"] + app.icon_type = args["icon_type"] + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id diff --git a/api/services/billing_service.py b/api/services/billing_service.py index d44483ad8..9fffde073 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -82,7 +82,7 @@ class BillingService: if not join: raise ValueError("Tenant account join not found") - if not TenantAccountRole.is_privileged_role(join.role): + if not TenantAccountRole.is_privileged_role(TenantAccountRole(join.role)): raise ValueError("Only team owner or team admin can perform this action") @classmethod diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 09cdd66e0..ce597420d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -215,9 +215,9 @@ class DatasetService: dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id - dataset.embedding_model_provider = embedding_model.provider if embedding_model else None - dataset.embedding_model = embedding_model.model if embedding_model else None - dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None + dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore + dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore + dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider db.session.add(dataset) @@ -1540,8 +1540,10 @@ class DocumentService: db.session.add(document) db.session.commit() # update document segment - update_params = {DocumentSegment.status: "re_segment"} - db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params) + + db.session.query(DocumentSegment).filter_by(document_id=document.id).update( + {DocumentSegment.status: "re_segment"} + ) # type: ignore db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -2226,7 +2228,7 @@ class SegmentService: # calc embedding use tokens if document.doc_form == "qa_model": segment.answer = args.answer - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] segment.content = content diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index a3f811faa..5710d660b 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -37,11 +37,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] try: # update segment status to indexing - update_params = { - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } - db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) db.session.commit() document = Document( page_content=segment.content, @@ -74,11 +75,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] index_processor.load(dataset, [document]) # update segment to completed - update_params = { - DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } - db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) db.session.commit() end_at = time.perf_counter() diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/unit_tests/models/test_types_enum_text.py index 3afa0f17a..908b5a536 100644 --- a/api/tests/unit_tests/models/test_types_enum_text.py +++ b/api/tests/unit_tests/models/test_types_enum_text.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy as sa from sqlalchemy import exc as sa_exc from sqlalchemy import insert -from sqlalchemy.orm import DeclarativeBase, Mapped, Session +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR from models.types import EnumText @@ -32,22 +32,26 @@ class _EnumWithLongValue(StrEnum): class _User(_Base): __tablename__ = "users" - id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) - name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False) - user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) - user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String(length=255), nullable=False) + user_type: Mapped[_UserType] = mapped_column( + EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal + ) + user_type_nullable: Mapped[_UserType | None] = mapped_column(EnumText(enum_class=_UserType), nullable=True) class _ColumnTest(_Base): __tablename__ = "column_test" - id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) - user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) - explicit_length: Mapped[_UserType | None] = sa.Column( + user_type: Mapped[_UserType] = mapped_column( + EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal + ) + explicit_length: Mapped[_UserType | None] = mapped_column( EnumText(_UserType, length=50), nullable=True, default=_UserType.normal ) - long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False) + long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) _T = TypeVar("_T") diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 29558a93c..dbd8f0509 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -95,7 +95,7 @@ def test_included_position_data(prepare_example_positions_yaml): position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml") pin_list = ["forth", "first"] include_set = {"forth", "first"} - exclude_set = {} + exclude_set = set() position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)